Posted to commits@kafka.apache.org by ju...@apache.org on 2018/02/05 18:09:11 UTC

[kafka] branch trunk updated: KAFKA-6254; Incremental fetch requests

This is an automated email from the ASF dual-hosted git repository.

junrao pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 7fe1c2b  KAFKA-6254; Incremental fetch requests
7fe1c2b is described below

commit 7fe1c2b3d3a78ea3ffb9e269563653626861fbd2
Author: Colin P. Mccabe <cm...@confluent.io>
AuthorDate: Mon Feb 5 10:09:17 2018 -0800

    KAFKA-6254; Incremental fetch requests
    
    Author: Colin P. Mccabe <cm...@confluent.io>
    
    Reviewers: Jason Gustafson <ja...@confluent.io>, Ismael Juma <is...@juma.me.uk>, Jun Rao <ju...@gmail.com>
    
    Closes #4418 from cmccabe/KAFKA-6254
---
 .../apache/kafka/clients/FetchSessionHandler.java  | 443 +++++++++++++
 .../kafka/clients/consumer/internals/Fetcher.java  |  88 +--
 .../errors/FetchSessionIdNotFoundException.java    |  29 +
 .../errors/InvalidFetchSessionEpochException.java  |  29 +
 .../org/apache/kafka/common/protocol/Errors.java   |  16 +
 .../apache/kafka/common/protocol/types/Struct.java |   6 +
 .../kafka/common/requests/FetchMetadata.java       | 154 +++++
 .../apache/kafka/common/requests/FetchRequest.java | 187 +++++-
 .../kafka/common/requests/FetchResponse.java       |  79 ++-
 .../kafka/common/utils/ImplicitLinkedHashSet.java  | 354 ++++++++++
 .../kafka/clients/FetchSessionHandlerTest.java     | 356 ++++++++++
 .../kafka/clients/consumer/KafkaConsumerTest.java  |   3 +-
 .../clients/consumer/internals/FetcherTest.java    | 175 +++--
 .../kafka/common/requests/RequestResponseTest.java |  70 +-
 .../common/utils/ImplicitLinkedHashSetTest.java    | 239 +++++++
 core/src/main/scala/kafka/api/ApiVersion.scala     |   7 +-
 .../src/main/scala/kafka/server/FetchSession.scala | 720 +++++++++++++++++++++
 core/src/main/scala/kafka/server/KafkaApis.scala   | 144 +++--
 core/src/main/scala/kafka/server/KafkaConfig.scala |  15 +
 core/src/main/scala/kafka/server/KafkaServer.scala |   7 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |  50 +-
 .../scala/unit/kafka/server/FetchRequestTest.scala |  55 ++
 .../scala/unit/kafka/server/FetchSessionTest.scala | 312 +++++++++
 .../scala/unit/kafka/server/KafkaApisTest.scala    |   2 +
 .../util/ReplicaFetcherMockBlockingSend.scala      |   7 +-
 25 files changed, 3329 insertions(+), 218 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
new file mode 100644
index 0000000..195324e
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
@@ -0,0 +1,443 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.clients;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FetchMetadata;
+import org.apache.kafka.common.requests.FetchRequest.PartitionData;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+
+/**
+ * FetchSessionHandler maintains the fetch session state for connecting to a broker.
+ *
+ * Using the protocol outlined by KIP-227, clients can create incremental fetch sessions.
+ * These sessions allow the client to fetch information about a set of partitions over
+ * and over, without explicitly enumerating all the partitions in the request and the
+ * response.
+ *
+ * FetchSessionHandler tracks the partitions which are in the session.  It also
+ * determines which partitions need to be included in each fetch request, and what
+ * the attached fetch session metadata should be for each request.  The corresponding
+ * class on the receiving broker side is FetchManager.
+ */
+public class FetchSessionHandler {
+    private final Logger log;
+
+    private final int node;
+
+    /**
+     * The metadata for the next fetch request.
+     */
+    private FetchMetadata nextMetadata = FetchMetadata.INITIAL;
+
+    public FetchSessionHandler(LogContext logContext, int node) {
+        this.log = logContext.logger(FetchSessionHandler.class);
+        this.node = node;
+    }
+
+    /**
+     * All of the partitions which exist in the fetch request session.
+     */
+    private LinkedHashMap<TopicPartition, PartitionData> sessionPartitions =
+        new LinkedHashMap<>(0);
+
+    public static class FetchRequestData {
+        /**
+         * The partitions to send in the fetch request.
+         */
+        private final Map<TopicPartition, PartitionData> toSend;
+
+        /**
+         * The partitions to send in the request's "forget" list.
+         */
+        private final List<TopicPartition> toForget;
+
+        /**
+         * All of the partitions which exist in the fetch request session.
+         */
+        private final Map<TopicPartition, PartitionData> sessionPartitions;
+
+        /**
+         * The metadata to use in this fetch request.
+         */
+        private final FetchMetadata metadata;
+
+        FetchRequestData(Map<TopicPartition, PartitionData> toSend,
+                         List<TopicPartition> toForget,
+                         Map<TopicPartition, PartitionData> sessionPartitions,
+                         FetchMetadata metadata) {
+            this.toSend = toSend;
+            this.toForget = toForget;
+            this.sessionPartitions = sessionPartitions;
+            this.metadata = metadata;
+        }
+
+        /**
+         * Get the set of partitions to send in this fetch request.
+         */
+        public Map<TopicPartition, PartitionData> toSend() {
+            return toSend;
+        }
+
+        /**
+         * Get a list of partitions to forget in this fetch request.
+         */
+        public List<TopicPartition> toForget() {
+            return toForget;
+        }
+
+        /**
+         * Get the full set of partitions involved in this fetch request.
+         */
+        public Map<TopicPartition, PartitionData> sessionPartitions() {
+            return sessionPartitions;
+        }
+
+        public FetchMetadata metadata() {
+            return metadata;
+        }
+
+        @Override
+        public String toString() {
+            if (metadata.isFull()) {
+                StringBuilder bld = new StringBuilder("FullFetchRequest(");
+                String prefix = "";
+                for (TopicPartition partition : toSend.keySet()) {
+                    bld.append(prefix);
+                    bld.append(partition);
+                    prefix = ", ";
+                }
+                bld.append(")");
+                return bld.toString();
+            } else {
+                StringBuilder bld = new StringBuilder("IncrementalFetchRequest(toSend=(");
+                String prefix = "";
+                for (TopicPartition partition : toSend.keySet()) {
+                    bld.append(prefix);
+                    bld.append(partition);
+                    prefix = ", ";
+                }
+                bld.append("), toForget=(");
+                prefix = "";
+                for (TopicPartition partition : toForget) {
+                    bld.append(prefix);
+                    bld.append(partition);
+                    prefix = ", ";
+                }
+                bld.append("), implied=(");
+                prefix = "";
+                for (TopicPartition partition : sessionPartitions.keySet()) {
+                    if (!toSend.containsKey(partition)) {
+                        bld.append(prefix);
+                        bld.append(partition);
+                        prefix = ", ";
+                    }
+                }
+                bld.append("))");
+                return bld.toString();
+            }
+        }
+    }
+
+    public class Builder {
+        /**
+         * The next partitions which we want to fetch.
+         *
+         * It is important to maintain the insertion order of this map by using a LinkedHashMap
+         * rather than an unordered Map implementation such as HashMap.
+         *
+         * One reason is that when dealing with FULL fetch requests, if there is not enough response
+         * space to return data from all partitions, the server will only return data from partitions
+         * early in this list.
+         *
+         * Another reason is because we make use of the list ordering to optimize the preparation of
+         * incremental fetch requests (see below).
+         */
+        private LinkedHashMap<TopicPartition, PartitionData> next = new LinkedHashMap<>();
+
+        /**
+         * Mark that we want data from this partition in the upcoming fetch.
+         */
+        public void add(TopicPartition topicPartition, PartitionData data) {
+            next.put(topicPartition, data);
+        }
+
+        public FetchRequestData build() {
+            if (nextMetadata.isFull()) {
+                log.debug("Built full fetch {} for node {} with {}.",
+                    nextMetadata, node, partitionsToLogString(next.keySet()));
+                sessionPartitions = next;
+                next = null;
+                Map<TopicPartition, PartitionData> toSend =
+                    Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions));
+                return new FetchRequestData(toSend, Collections.<TopicPartition>emptyList(), toSend, nextMetadata);
+            }
+
+            List<TopicPartition> added = new ArrayList<>();
+            List<TopicPartition> removed = new ArrayList<>();
+            List<TopicPartition> altered = new ArrayList<>();
+            for (Iterator<Entry<TopicPartition, PartitionData>> iter =
+                     sessionPartitions.entrySet().iterator(); iter.hasNext(); ) {
+                Entry<TopicPartition, PartitionData> entry = iter.next();
+                TopicPartition topicPartition = entry.getKey();
+                PartitionData prevData = entry.getValue();
+                PartitionData nextData = next.get(topicPartition);
+                if (nextData != null) {
+                    if (prevData.equals(nextData)) {
+                        // Omit this partition from the FetchRequest, because it hasn't changed
+                        // since the previous request.
+                        next.remove(topicPartition);
+                    } else {
+                        // Move the altered partition to the end of 'next'
+                        next.remove(topicPartition);
+                        next.put(topicPartition, nextData);
+                        entry.setValue(nextData);
+                        altered.add(topicPartition);
+                    }
+                } else {
+                    // Remove this partition from the session.
+                    iter.remove();
+                    // Indicate that we no longer want to listen to this partition.
+                    removed.add(topicPartition);
+                }
+            }
+            // Add any new partitions to the session.
+            for (Iterator<Entry<TopicPartition, PartitionData>> iter =
+                     next.entrySet().iterator(); iter.hasNext(); ) {
+                Entry<TopicPartition, PartitionData> entry = iter.next();
+                TopicPartition topicPartition = entry.getKey();
+                PartitionData nextData = entry.getValue();
+                if (sessionPartitions.containsKey(topicPartition)) {
+                    // In the previous loop, all the partitions which existed in both sessionPartitions
+                    // and next were moved to the end of next, or removed from next.  Therefore,
+                    // once we hit one of them, we know there are no more unseen entries to look
+                    // at in next.
+                    break;
+                }
+                sessionPartitions.put(topicPartition, nextData);
+                added.add(topicPartition);
+            }
+            log.debug("Built incremental fetch {} for node {}. Added {}, altered {}, removed {} " +
+                    "out of {}", nextMetadata, node, partitionsToLogString(added),
+                    partitionsToLogString(altered), partitionsToLogString(removed),
+                    partitionsToLogString(sessionPartitions.keySet()));
+            Map<TopicPartition, PartitionData> toSend =
+                Collections.unmodifiableMap(new LinkedHashMap<>(next));
+            Map<TopicPartition, PartitionData> curSessionPartitions =
+                Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions));
+            next = null;
+            return new FetchRequestData(toSend, Collections.unmodifiableList(removed),
+                curSessionPartitions, nextMetadata);
+        }
+    }
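+
+    // [Editor's note, not part of the original patch] A worked example of
+    // Builder.build() for an incremental request.  Suppose the previous session
+    // contained partitions {A, B, C} and the caller adds {A', C, D}, where A'
+    // carries changed PartitionData for A.  Then build() classifies:
+    //   A  -> altered:   kept in 'next' and sent with the new data.
+    //   B  -> removed:   dropped from the session and sent in toForget.
+    //   C  -> unchanged: omitted from the request, but implied by the session.
+    //   D  -> added:     joins the session and is sent.
+    // Result: toSend = {A', D}, toForget = [B], sessionPartitions = {A', C, D}.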
+
+    public Builder newBuilder() {
+        return new Builder();
+    }
+
+    private String partitionsToLogString(Collection<TopicPartition> partitions) {
+        if (!log.isTraceEnabled()) {
+            return String.format("%d partition(s)", partitions.size());
+        }
+        return "(" + Utils.join(partitions, ", ") + ")";
+    }
+
+    /**
+     * Return the partitions which are expected to be in a particular set, but which are not.
+     *
+     * @param toFind    The partitions to look for.
+     * @param toSearch  The set of partitions to search.
+     * @return          The partitions in toFind which are not present in toSearch;
+     *                  an empty set if all of them were found.
+     */
+    static Set<TopicPartition> findMissing(Set<TopicPartition> toFind, Set<TopicPartition> toSearch) {
+        Set<TopicPartition> ret = new LinkedHashSet<>();
+        for (TopicPartition partition : toFind) {
+            if (!toSearch.contains(partition)) {
+                ret.add(partition);
+            }
+        }
+        return ret;
+    }
+
+    /**
+     * Verify that a full fetch response contains exactly the partitions in the fetch session.
+     *
+     * @param response  The response.
+     * @return          null if the response partitions are valid; a String describing
+     *                  the omitted or extra partitions, if not.
+     */
+    private String verifyFullFetchResponsePartitions(FetchResponse response) {
+        StringBuilder bld = new StringBuilder();
+        Set<TopicPartition> omitted =
+            findMissing(response.responseData().keySet(), sessionPartitions.keySet());
+        Set<TopicPartition> extra =
+            findMissing(sessionPartitions.keySet(), response.responseData().keySet());
+        if (!omitted.isEmpty()) {
+            bld.append("omitted=(").append(Utils.join(omitted, ", ")).append(", ");
+        }
+        if (!extra.isEmpty()) {
+            bld.append("extra=(").append(Utils.join(extra, ", ")).append(", ");
+        }
+        if ((!omitted.isEmpty()) || (!extra.isEmpty())) {
+            bld.append("response=(").append(Utils.join(response.responseData().keySet(), ", "));
+            return bld.toString();
+        }
+        return null;
+    }
+
+    /**
+     * Verify that the partitions in an incremental fetch response are contained in the session.
+     *
+     * @param response  The response.
+     * @return          null if the response partitions are valid; a String describing
+     *                  the extra partitions, if not.
+     */
+    private String verifyIncrementalFetchResponsePartitions(FetchResponse response) {
+        Set<TopicPartition> extra =
+            findMissing(response.responseData().keySet(), sessionPartitions.keySet());
+        if (!extra.isEmpty()) {
+            StringBuilder bld = new StringBuilder();
+            bld.append("extra=(").append(Utils.join(extra, ", ")).append("), ");
+            bld.append("response=(").append(
+                Utils.join(response.responseData().keySet(), ", ")).append("), ");
+            return bld.toString();
+        }
+        return null;
+    }
+
+    /**
+     * Create a string describing the partitions in a FetchResponse.
+     *
+     * @param response  The FetchResponse.
+     * @return          The string to log.
+     */
+    private String responseDataToLogString(FetchResponse response) {
+        if (!log.isTraceEnabled()) {
+            int implied = sessionPartitions.size() - response.responseData().size();
+            if (implied > 0) {
+                return String.format(" with %d response partition(s), %d implied partition(s)",
+                    response.responseData().size(), implied);
+            } else {
+                return String.format(" with %d response partition(s)",
+                    response.responseData().size());
+            }
+        }
+        StringBuilder bld = new StringBuilder();
+        bld.append(" with response=(").
+            append(Utils.join(response.responseData().keySet(), ", ")).
+            append(")");
+        String prefix = ", implied=(";
+        String suffix = "";
+        for (TopicPartition partition : sessionPartitions.keySet()) {
+            if (!response.responseData().containsKey(partition)) {
+                bld.append(prefix);
+                bld.append(partition);
+                prefix = ", ";
+                suffix = ")";
+            }
+        }
+        bld.append(suffix);
+        return bld.toString();
+    }
+
+    /**
+     * Handle the fetch response.
+     *
+     * @param response  The response.
+     * @return          True if the response is well-formed; false if it can't be processed
+     *                  because of missing or unexpected partitions.
+     */
+    public boolean handleResponse(FetchResponse response) {
+        if (response.error() != Errors.NONE) {
+            log.info("Node {} was unable to process the fetch request with {}: {}.",
+                node, nextMetadata, response.error());
+            if (response.error() == Errors.FETCH_SESSION_ID_NOT_FOUND) {
+                nextMetadata = FetchMetadata.INITIAL;
+            } else {
+                nextMetadata = nextMetadata.nextCloseExisting();
+            }
+            return false;
+        } else if (nextMetadata.isFull()) {
+            String problem = verifyFullFetchResponsePartitions(response);
+            if (problem != null) {
+                log.info("Node {} sent an invalid full fetch response with {}", node, problem);
+                nextMetadata = FetchMetadata.INITIAL;
+                return false;
+            } else if (response.sessionId() == INVALID_SESSION_ID) {
+                log.debug("Node {} sent a full fetch response{}",
+                    node, responseDataToLogString(response));
+                nextMetadata = FetchMetadata.INITIAL;
+                return true;
+            } else {
+                // The server created a new incremental fetch session.
+                log.debug("Node {} sent a full fetch response that created a new incremental " +
+                    "fetch session {}{}", node, response.sessionId(), responseDataToLogString(response));
+                nextMetadata = FetchMetadata.newIncremental(response.sessionId());
+                return true;
+            }
+        } else {
+            String problem = verifyIncrementalFetchResponsePartitions(response);
+            if (problem != null) {
+                log.info("Node {} sent an invalid incremental fetch response with {}", node, problem);
+                nextMetadata = nextMetadata.nextCloseExisting();
+                return false;
+            } else if (response.sessionId() == INVALID_SESSION_ID) {
+                // The incremental fetch session was closed by the server.
+                log.debug("Node {} sent an incremental fetch response closing session {}{}",
+                    node, nextMetadata.sessionId(), responseDataToLogString(response));
+                nextMetadata = FetchMetadata.INITIAL;
+                return true;
+            } else {
+                // The incremental fetch session was continued by the server.
+                log.debug("Node {} sent an incremental fetch response for session {}{}",
+                    node, response.sessionId(), responseDataToLogString(response));
+                nextMetadata = nextMetadata.nextIncremental();
+                return true;
+            }
+        }
+    }
+
+    /**
+     * Handle an error sending the prepared request.
+     *
+     * When a network error occurs, we close any existing fetch session on our next request,
+     * and try to create a new session.
+     *
+     * @param t     The exception.
+     */
+    public void handleError(Throwable t) {
+        log.info("Error sending fetch request {} to node {}: {}.", nextMetadata, node, t.toString());
+        nextMetadata = nextMetadata.nextCloseExisting();
+    }
+}
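
[Editor's note] To make the new API concrete, here is a minimal sketch -- not
part of the patch -- of how a caller is expected to drive FetchSessionHandler.
The constants (maxWait=500, minBytes=1, maxBytes=1MB) are illustrative only;
the real wiring lives in Fetcher.sendFetches() below.

    import org.apache.kafka.clients.FetchSessionHandler;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.requests.FetchRequest;

    public class FetchSessionSketch {
        public static FetchRequest.Builder nextRequest(FetchSessionHandler handler,
                                                       TopicPartition tp, long position) {
            // Register the partitions wanted in the upcoming fetch.
            FetchSessionHandler.Builder builder = handler.newBuilder();
            builder.add(tp, new FetchRequest.PartitionData(
                position, FetchRequest.INVALID_LOG_START_OFFSET, 1024 * 1024));
            // build() diffs against the session state: full data on the first
            // call, only changed/forgotten partitions afterwards.
            FetchSessionHandler.FetchRequestData data = builder.build();
            return FetchRequest.Builder.forConsumer(500, 1, data.toSend())
                .metadata(data.metadata())
                .toForget(data.toForget());
        }
    }

After the response arrives, handler.handleResponse(response) advances the
session epoch (or resets the metadata so the next request starts a fresh
session); on a network error, handler.handleError(t) does the same.
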
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 6d56139..32782ee 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.ClientResponse;
+import org.apache.kafka.clients.FetchSessionHandler;
 import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
@@ -92,6 +93,7 @@ import static org.apache.kafka.common.serialization.ExtendedDeserializer.Wrapper
  */
 public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     private final Logger log;
+    private final LogContext logContext;
     private final ConsumerNetworkClient client;
     private final Time time;
     private final int minBytes;
@@ -110,6 +112,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
     private final ExtendedDeserializer<K> keyDeserializer;
     private final ExtendedDeserializer<V> valueDeserializer;
     private final IsolationLevel isolationLevel;
+    private final Map<Integer, FetchSessionHandler> sessionHandlers;
 
     private PartitionRecords nextInLineRecords = null;
 
@@ -131,6 +134,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
                    long retryBackoffMs,
                    IsolationLevel isolationLevel) {
         this.log = logContext.logger(Fetcher.class);
+        this.logContext = logContext;
         this.time = time;
         this.client = client;
         this.metadata = metadata;
@@ -147,6 +151,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
         this.sensors = new FetchManagerMetrics(metrics, metricsRegistry);
         this.retryBackoffMs = retryBackoffMs;
         this.isolationLevel = isolationLevel;
+        this.sessionHandlers = new HashMap<>();
 
         subscriptions.addListener(this);
     }
@@ -181,36 +186,37 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
         return !completedFetches.isEmpty();
     }
 
-    private boolean matchesRequestedPartitions(FetchRequest.Builder request, FetchResponse response) {
-        Set<TopicPartition> requestedPartitions = request.fetchData().keySet();
-        Set<TopicPartition> fetchedPartitions = response.responseData().keySet();
-        return fetchedPartitions.equals(requestedPartitions);
-    }
-
     /**
      * Set-up a fetch request for any node that we have assigned partitions for which doesn't already have
      * an in-flight fetch or pending fetch data.
      * @return number of fetches sent
      */
     public int sendFetches() {
-        Map<Node, FetchRequest.Builder> fetchRequestMap = createFetchRequests();
-        for (Map.Entry<Node, FetchRequest.Builder> fetchEntry : fetchRequestMap.entrySet()) {
-            final FetchRequest.Builder request = fetchEntry.getValue();
-            final Node fetchTarget = fetchEntry.getKey();
-
-            log.debug("Sending {} fetch for partitions {} to broker {}", isolationLevel, request.fetchData().keySet(),
-                    fetchTarget);
+        Map<Node, FetchSessionHandler.FetchRequestData> fetchRequestMap = prepareFetchRequests();
+        for (Map.Entry<Node, FetchSessionHandler.FetchRequestData> entry : fetchRequestMap.entrySet()) {
+            final Node fetchTarget = entry.getKey();
+            final FetchSessionHandler.FetchRequestData data = entry.getValue();
+            final FetchRequest.Builder request = FetchRequest.Builder.
+                forConsumer(this.maxWaitMs, this.minBytes, data.toSend())
+                .isolationLevel(isolationLevel)
+                .setMaxBytes(this.maxBytes)
+                .metadata(data.metadata())
+                .toForget(data.toForget());
+            if (log.isDebugEnabled()) {
+                log.debug("Sending {} {} to broker {}", isolationLevel, data.toString(), fetchTarget);
+            }
             client.send(fetchTarget, request)
                     .addListener(new RequestFutureListener<ClientResponse>() {
                         @Override
                         public void onSuccess(ClientResponse resp) {
                             FetchResponse response = (FetchResponse) resp.responseBody();
-                            if (!matchesRequestedPartitions(request, response)) {
-                                // obviously we expect the broker to always send us valid responses, so this check
-                                // is mainly for test cases where mock fetch responses must be manually crafted.
-                                log.warn("Ignoring fetch response containing partitions {} since it does not match " +
-                                        "the requested partitions {}", response.responseData().keySet(),
-                                        request.fetchData().keySet());
+                            FetchSessionHandler handler = sessionHandlers.get(fetchTarget.id());
+                            if (handler == null) {
+                                log.error("Unable to find FetchSessionHandler for node {}. Ignoring fetch response.",
+                                    fetchTarget.id());
+                                return;
+                            }
+                            if (!handler.handleResponse(response)) {
                                 return;
                             }
 
@@ -219,7 +225,7 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
 
                             for (Map.Entry<TopicPartition, FetchResponse.PartitionData> entry : response.responseData().entrySet()) {
                                 TopicPartition partition = entry.getKey();
-                                long fetchOffset = request.fetchData().get(partition).fetchOffset;
+                                long fetchOffset = data.sessionPartitions().get(partition).fetchOffset;
                                 FetchResponse.PartitionData fetchData = entry.getValue();
 
                                 log.debug("Fetch {} at offset {} for partition {} returned fetch data {}",
@@ -233,7 +239,10 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
 
                         @Override
                         public void onFailure(RuntimeException e) {
-                            log.debug("Fetch request {} to {} failed", request.fetchData(), fetchTarget, e);
+                            FetchSessionHandler handler = sessionHandlers.get(fetchTarget.id());
+                            if (handler != null) {
+                                handler.handleError(e);
+                            }
                         }
                     });
         }
@@ -772,42 +781,41 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
      * Create fetch requests for all nodes for which we have assigned partitions
      * that have no existing requests in flight.
      */
-    private Map<Node, FetchRequest.Builder> createFetchRequests() {
+    private Map<Node, FetchSessionHandler.FetchRequestData> prepareFetchRequests() {
         // create the fetch info
         Cluster cluster = metadata.fetch();
-        Map<Node, LinkedHashMap<TopicPartition, FetchRequest.PartitionData>> fetchable = new LinkedHashMap<>();
+        Map<Node, FetchSessionHandler.Builder> fetchable = new LinkedHashMap<>();
         for (TopicPartition partition : fetchablePartitions()) {
             Node node = cluster.leaderFor(partition);
             if (node == null) {
                 metadata.requestUpdate();
             } else if (!this.client.hasPendingRequests(node)) {
                 // if there is a leader and no in-flight requests, issue a new fetch
-                LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetch = fetchable.get(node);
-                if (fetch == null) {
-                    fetch = new LinkedHashMap<>();
-                    fetchable.put(node, fetch);
+                FetchSessionHandler.Builder builder = fetchable.get(node);
+                if (builder == null) {
+                    FetchSessionHandler handler = sessionHandlers.get(node.id());
+                    if (handler == null) {
+                        handler = new FetchSessionHandler(logContext, node.id());
+                        sessionHandlers.put(node.id(), handler);
+                    }
+                    builder = handler.newBuilder();
+                    fetchable.put(node, builder);
                 }
 
                 long position = this.subscriptions.position(partition);
-                fetch.put(partition, new FetchRequest.PartitionData(position, FetchRequest.INVALID_LOG_START_OFFSET,
-                        this.fetchSize));
+                builder.add(partition, new FetchRequest.PartitionData(position, FetchRequest.INVALID_LOG_START_OFFSET,
+                    this.fetchSize));
                 log.debug("Added {} fetch request for partition {} at offset {} to node {}", isolationLevel,
-                        partition, position, node);
+                    partition, position, node);
             } else {
                 log.trace("Skipping fetch for partition {} because there is an in-flight request to {}", partition, node);
             }
         }
-
-        // create the fetches
-        Map<Node, FetchRequest.Builder> requests = new HashMap<>();
-        for (Map.Entry<Node, LinkedHashMap<TopicPartition, FetchRequest.PartitionData>> entry : fetchable.entrySet()) {
-            Node node = entry.getKey();
-            FetchRequest.Builder fetch = FetchRequest.Builder.forConsumer(this.maxWaitMs, this.minBytes,
-                    entry.getValue(), isolationLevel)
-                    .setMaxBytes(this.maxBytes);
-            requests.put(node, fetch);
+        Map<Node, FetchSessionHandler.FetchRequestData> reqs = new LinkedHashMap<>();
+        for (Map.Entry<Node, FetchSessionHandler.Builder> entry : fetchable.entrySet()) {
+            reqs.put(entry.getKey(), entry.getValue().build());
         }
-        return requests;
+        return reqs;
     }
 
     /**
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java
new file mode 100644
index 0000000..2ce5f74
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.errors;
+
+public class FetchSessionIdNotFoundException extends RetriableException {
+    private static final long serialVersionUID = 1L;
+
+    public FetchSessionIdNotFoundException() {
+    }
+
+    public FetchSessionIdNotFoundException(String message) {
+        super(message);
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java
new file mode 100644
index 0000000..3b135c0
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.errors;
+
+public class InvalidFetchSessionEpochException extends RetriableException {
+    private static final long serialVersionUID = 1L;
+
+    public InvalidFetchSessionEpochException() {
+    }
+
+    public InvalidFetchSessionEpochException(String message) {
+        super(message);
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
index e2b8aea..4b44c18 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
@@ -30,6 +30,7 @@ import org.apache.kafka.common.errors.DelegationTokenDisabledException;
 import org.apache.kafka.common.errors.DelegationTokenExpiredException;
 import org.apache.kafka.common.errors.DelegationTokenNotFoundException;
 import org.apache.kafka.common.errors.DelegationTokenOwnerMismatchException;
+import org.apache.kafka.common.errors.FetchSessionIdNotFoundException;
 import org.apache.kafka.common.errors.GroupAuthorizationException;
 import org.apache.kafka.common.errors.GroupIdNotFoundException;
 import org.apache.kafka.common.errors.GroupNotEmptyException;
@@ -38,6 +39,7 @@ import org.apache.kafka.common.errors.IllegalSaslStateException;
 import org.apache.kafka.common.errors.InconsistentGroupProtocolException;
 import org.apache.kafka.common.errors.InvalidCommitOffsetSizeException;
 import org.apache.kafka.common.errors.InvalidConfigurationException;
+import org.apache.kafka.common.errors.InvalidFetchSessionEpochException;
 import org.apache.kafka.common.errors.InvalidFetchSizeException;
 import org.apache.kafka.common.errors.InvalidGroupIdException;
 import org.apache.kafka.common.errors.InvalidPartitionsException;
@@ -608,6 +610,20 @@ public enum Errors {
         public ApiException build(String message) {
             return new GroupIdNotFoundException(message);
         }
+    }),
+    FETCH_SESSION_ID_NOT_FOUND(70, "The fetch session ID was not found",
+        new ApiExceptionBuilder() {
+            @Override
+            public ApiException build(String message) {
+                return new FetchSessionIdNotFoundException(message);
+            }
+    }),
+    INVALID_FETCH_SESSION_EPOCH(71, "The fetch session epoch is invalid",
+        new ApiExceptionBuilder() {
+            @Override
+            public ApiException build(String message) {
+                return new InvalidFetchSessionEpochException(message);
+            }
     });
 
     private interface ApiExceptionBuilder {
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java
index 6fb6b20..ac24a1b 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java
@@ -105,6 +105,12 @@ public class Struct {
         return alternative;
     }
 
+    public Short getOrElse(Field.Int16 field, short alternative) {
+        if (hasField(field.name))
+            return getShort(field.name);
+        return alternative;
+    }
+
     public Integer getOrElse(Field.Int32 field, int alternative) {
         if (hasField(field.name))
             return getInt(field.name);
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java
new file mode 100644
index 0000000..feb6953
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.requests;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Objects;
+
+public class FetchMetadata {
+    public static final Logger log = LoggerFactory.getLogger(FetchMetadata.class);
+
+    /**
+     * The session ID used by clients with no session.
+     */
+    public static final int INVALID_SESSION_ID = 0;
+
+    /**
+     * The first epoch.  When used in a fetch request, indicates that the client
+     * wants to create or recreate a session.
+     */
+    public static final int INITIAL_EPOCH = 0;
+
+    /**
+     * An invalid epoch.  When used in a fetch request, indicates that the client
+     * wants to close any existing session, and not create a new one.
+     */
+    public static final int FINAL_EPOCH = -1;
+
+    /**
+     * The FetchMetadata that is used when initializing a new FetchSessionHandler.
+     */
+    public static final FetchMetadata INITIAL = new FetchMetadata(INVALID_SESSION_ID, INITIAL_EPOCH);
+
+    /**
+     * The FetchMetadata that is implicitly used for handling older FetchRequests that
+     * don't include fetch metadata.
+     */
+    public static final FetchMetadata LEGACY = new FetchMetadata(INVALID_SESSION_ID, FINAL_EPOCH);
+
+    /**
+     * Returns the next epoch.
+     *
+     * @param prevEpoch The previous epoch.
+     * @return          The next epoch.
+     */
+    public static int nextEpoch(int prevEpoch) {
+        if (prevEpoch < 0) {
+            // The next epoch after FINAL_EPOCH is always FINAL_EPOCH itself.
+            return FINAL_EPOCH;
+        } else if (prevEpoch == Integer.MAX_VALUE) {
+            return 1;
+        } else {
+            return prevEpoch + 1;
+        }
+    }
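+
+    // [Editor's note, not in the patch] The wrap above skips 0 and -1 so that a
+    // long-lived session's epoch can never collide with INITIAL_EPOCH or
+    // FINAL_EPOCH, which would silently turn an incremental request into a full
+    // or final one.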
+
+    /**
+     * The fetch session ID.
+     */
+    private final int sessionId;
+
+    /**
+     * The fetch session epoch.
+     */
+    private final int epoch;
+
+    public FetchMetadata(int sessionId, int epoch) {
+        this.sessionId = sessionId;
+        this.epoch = epoch;
+    }
+
+    /**
+     * Returns true if this is a full fetch request.
+     */
+    public boolean isFull() {
+        return (this.epoch == INITIAL_EPOCH) || (this.epoch == FINAL_EPOCH);
+    }
+
+    public int sessionId() {
+        return sessionId;
+    }
+
+    public int epoch() {
+        return epoch;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(sessionId, epoch);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        FetchMetadata that = (FetchMetadata) o;
+        return sessionId == that.sessionId && epoch == that.epoch;
+    }
+
+    /**
+     * Return the metadata for the next request, closing the existing fetch session and creating a new one.
+     */
+    public FetchMetadata nextCloseExisting() {
+        return new FetchMetadata(sessionId, INITIAL_EPOCH);
+    }
+
+    /**
+     * Return the metadata for the first incremental fetch request in a newly created session.
+     */
+    public static FetchMetadata newIncremental(int sessionId) {
+        return new FetchMetadata(sessionId, nextEpoch(INITIAL_EPOCH));
+    }
+
+    /**
+     * Return the metadata for the next incremental fetch request.
+     */
+    public FetchMetadata nextIncremental() {
+        return new FetchMetadata(sessionId, nextEpoch(epoch));
+    }
+
+    @Override
+    public String toString() {
+        StringBuilder bld = new StringBuilder();
+        if (sessionId == INVALID_SESSION_ID) {
+            bld.append("(sessionId=INVALID, ");
+        } else {
+            bld.append("(sessionId=").append(sessionId).append(", ");
+        }
+        if (epoch == INITIAL_EPOCH) {
+            bld.append("epoch=INITIAL)");
+        } else if (epoch == FINAL_EPOCH) {
+            bld.append("epoch=FINAL)");
+        } else {
+            bld.append("epoch=").append(epoch).append(")");
+        }
+        return bld.toString();
+    }
+}
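
[Editor's note] A small demo -- not from the patch -- of the epoch state
machine defined above; each comment follows from nextEpoch(), isFull(), and
the factory methods:

    import org.apache.kafka.common.requests.FetchMetadata;

    public class EpochDemo {
        public static void main(String[] args) {
            FetchMetadata m = FetchMetadata.INITIAL;  // (sessionId=INVALID, epoch=INITIAL)
            System.out.println(m.isFull());           // true: full fetch, asks for a session
            m = FetchMetadata.newIncremental(123);    // (sessionId=123, epoch=1)
            m = m.nextIncremental();                  // (sessionId=123, epoch=2)
            System.out.println(m.isFull());           // false: incremental fetch
            m = m.nextCloseExisting();                // (sessionId=123, epoch=INITIAL)
            System.out.println(m.isFull());           // true: closes 123, creates a new session
        }
    }
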
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
index 18425d0..65cf7fe 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
@@ -23,19 +23,27 @@ import org.apache.kafka.common.protocol.types.ArrayOf;
 import org.apache.kafka.common.protocol.types.Field;
 import org.apache.kafka.common.protocol.types.Schema;
 import org.apache.kafka.common.protocol.types.Struct;
+import org.apache.kafka.common.protocol.types.Type;
 import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.utils.Utils;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 import static org.apache.kafka.common.protocol.CommonFields.PARTITION_ID;
 import static org.apache.kafka.common.protocol.CommonFields.TOPIC_NAME;
 import static org.apache.kafka.common.protocol.types.Type.INT32;
 import static org.apache.kafka.common.protocol.types.Type.INT64;
 import static org.apache.kafka.common.protocol.types.Type.INT8;
+import static org.apache.kafka.common.requests.FetchMetadata.FINAL_EPOCH;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
 
 public class FetchRequest extends AbstractRequest {
     public static final int CONSUMER_REPLICA_ID = -1;
@@ -44,6 +52,7 @@ public class FetchRequest extends AbstractRequest {
     private static final String MIN_BYTES_KEY_NAME = "min_bytes";
     private static final String ISOLATION_LEVEL_KEY_NAME = "isolation_level";
     private static final String TOPICS_KEY_NAME = "topics";
+    private static final String FORGOTTEN_TOPICS_DATA = "forgetten_topics_data";
 
     // request and partition level name
     private static final String MAX_BYTES_KEY_NAME = "max_bytes";
@@ -139,9 +148,36 @@ public class FetchRequest extends AbstractRequest {
      */
     private static final Schema FETCH_REQUEST_V6 = FETCH_REQUEST_V5;
 
+    // FETCH_REQUEST_V7 added incremental fetch requests.
+    public static final Field.Int32 SESSION_ID = new Field.Int32("session_id", "The fetch session ID");
+    public static final Field.Int32 EPOCH = new Field.Int32("epoch", "The fetch epoch");
+
+    private static final Schema FORGOTTEN_TOPIC_DATA = new Schema(
+        TOPIC_NAME,
+        new Field(PARTITIONS_KEY_NAME, new ArrayOf(Type.INT32),
+            "Partitions to remove from the fetch session."));
+
+    private static final Schema FETCH_REQUEST_V7 = new Schema(
+        new Field(REPLICA_ID_KEY_NAME, INT32, "Broker id of the follower. For normal consumers, use -1."),
+        new Field(MAX_WAIT_KEY_NAME, INT32, "Maximum time in ms to wait for the response."),
+        new Field(MIN_BYTES_KEY_NAME, INT32, "Minimum bytes to accumulate in the response."),
+        new Field(MAX_BYTES_KEY_NAME, INT32, "Maximum bytes to accumulate in the response. Note that this is not an absolute maximum, " +
+            "if the first message in the first non-empty partition of the fetch is larger than this " +
+            "value, the message will still be returned to ensure that progress can be made."),
+        new Field(ISOLATION_LEVEL_KEY_NAME, INT8, "This setting controls the visibility of transactional records. Using READ_UNCOMMITTED " +
+            "(isolation_level = 0) makes all records visible. With READ_COMMITTED (isolation_level = 1), " +
+            "non-transactional and COMMITTED transactional records are visible. To be more concrete, " +
+            "READ_COMMITTED returns all data from offsets smaller than the current LSO (last stable offset), " +
+            "and enables the inclusion of the list of aborted transactions in the result, which allows " +
+            "consumers to discard ABORTED transactional records"),
+        SESSION_ID,
+        EPOCH,
+        new Field(TOPICS_KEY_NAME, new ArrayOf(FETCH_REQUEST_TOPIC_V5), "Topics to fetch in the order provided."),
+        new Field(FORGOTTEN_TOPICS_DATA, new ArrayOf(FORGOTTEN_TOPIC_DATA), "Topics to remove from the fetch session."));
+
     public static Schema[] schemaVersions() {
         return new Schema[]{FETCH_REQUEST_V0, FETCH_REQUEST_V1, FETCH_REQUEST_V2, FETCH_REQUEST_V3, FETCH_REQUEST_V4,
-            FETCH_REQUEST_V5, FETCH_REQUEST_V6};
+            FETCH_REQUEST_V5, FETCH_REQUEST_V6, FETCH_REQUEST_V7};
     };
 
     // default values for older versions where a request level limit did not exist
@@ -153,7 +189,14 @@ public class FetchRequest extends AbstractRequest {
     private final int minBytes;
     private final int maxBytes;
     private final IsolationLevel isolationLevel;
-    private final LinkedHashMap<TopicPartition, PartitionData> fetchData;
+
+    // Note: the iteration order of this map is significant, since it determines the order
+    // in which partitions appear in the message.  For this reason, this map should have a
+    // deterministic iteration order, like LinkedHashMap or TreeMap (but unlike HashMap).
+    private final Map<TopicPartition, PartitionData> fetchData;
+
+    private final List<TopicPartition> toForget;
+    private final FetchMetadata metadata;
 
     public static final class PartitionData {
         public final long fetchOffset;
@@ -170,6 +213,21 @@ public class FetchRequest extends AbstractRequest {
         public String toString() {
             return "(offset=" + fetchOffset + ", logStartOffset=" + logStartOffset + ", maxBytes=" + maxBytes + ")";
         }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(fetchOffset, logStartOffset, maxBytes);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            PartitionData that = (PartitionData) o;
+            return Objects.equals(fetchOffset, that.fetchOffset) &&
+                Objects.equals(logStartOffset, that.logStartOffset) &&
+                Objects.equals(maxBytes, that.maxBytes);
+        }
     }
 
     static final class TopicAndPartitionData<T> {
@@ -181,9 +239,10 @@ public class FetchRequest extends AbstractRequest {
             this.partitions = new LinkedHashMap<>();
         }
 
-        public static <T> List<TopicAndPartitionData<T>> batchByTopic(LinkedHashMap<TopicPartition, T> data) {
+        public static <T> List<TopicAndPartitionData<T>> batchByTopic(Iterator<Map.Entry<TopicPartition, T>> iter) {
             List<TopicAndPartitionData<T>> topics = new ArrayList<>();
-            for (Map.Entry<TopicPartition, T> topicEntry : data.entrySet()) {
+            while (iter.hasNext()) {
+                Map.Entry<TopicPartition, T> topicEntry = iter.next();
                 String topic = topicEntry.getKey().topic();
                 int partition = topicEntry.getKey().partition();
                 T partitionData = topicEntry.getValue();
@@ -199,37 +258,42 @@ public class FetchRequest extends AbstractRequest {
         private final int maxWait;
         private final int minBytes;
         private final int replicaId;
-        private final LinkedHashMap<TopicPartition, PartitionData> fetchData;
-        private final IsolationLevel isolationLevel;
+        private final Map<TopicPartition, PartitionData> fetchData;
+        private IsolationLevel isolationLevel = IsolationLevel.READ_UNCOMMITTED;
         private int maxBytes = DEFAULT_RESPONSE_MAX_BYTES;
+        private FetchMetadata metadata = FetchMetadata.LEGACY;
+        private List<TopicPartition> toForget = Collections.<TopicPartition>emptyList();
 
-        public static Builder forConsumer(int maxWait, int minBytes, LinkedHashMap<TopicPartition, PartitionData> fetchData) {
-            return forConsumer(maxWait, minBytes, fetchData, IsolationLevel.READ_UNCOMMITTED);
-        }
-
-        public static Builder forConsumer(int maxWait, int minBytes, LinkedHashMap<TopicPartition, PartitionData> fetchData,
-                                          IsolationLevel isolationLevel) {
-            return new Builder(ApiKeys.FETCH.oldestVersion(), ApiKeys.FETCH.latestVersion(), CONSUMER_REPLICA_ID,
-                    maxWait, minBytes, fetchData, isolationLevel);
+        public static Builder forConsumer(int maxWait, int minBytes, Map<TopicPartition, PartitionData> fetchData) {
+            return new Builder(ApiKeys.FETCH.oldestVersion(), ApiKeys.FETCH.latestVersion(),
+                CONSUMER_REPLICA_ID, maxWait, minBytes, fetchData);
         }
 
         public static Builder forReplica(short allowedVersion, int replicaId, int maxWait, int minBytes,
-                                         LinkedHashMap<TopicPartition, PartitionData> fetchData) {
-            return new Builder(allowedVersion, allowedVersion, replicaId, maxWait, minBytes, fetchData,
-                    IsolationLevel.READ_UNCOMMITTED);
+                                         Map<TopicPartition, PartitionData> fetchData) {
+            return new Builder(allowedVersion, allowedVersion, replicaId, maxWait, minBytes, fetchData);
         }
 
-        private Builder(short minVersion, short maxVersion, int replicaId, int maxWait, int minBytes,
-                        LinkedHashMap<TopicPartition, PartitionData> fetchData, IsolationLevel isolationLevel) {
+        public Builder(short minVersion, short maxVersion, int replicaId, int maxWait, int minBytes,
+                        Map<TopicPartition, PartitionData> fetchData) {
             super(ApiKeys.FETCH, minVersion, maxVersion);
             this.replicaId = replicaId;
             this.maxWait = maxWait;
             this.minBytes = minBytes;
             this.fetchData = fetchData;
+        }
+
+        public Builder isolationLevel(IsolationLevel isolationLevel) {
             this.isolationLevel = isolationLevel;
+            return this;
+        }
+
+        public Builder metadata(FetchMetadata metadata) {
+            this.metadata = metadata;
+            return this;
         }
 
-        public LinkedHashMap<TopicPartition, PartitionData> fetchData() {
+        public Map<TopicPartition, PartitionData> fetchData() {
             return this.fetchData;
         }
 
@@ -238,13 +302,23 @@ public class FetchRequest extends AbstractRequest {
             return this;
         }
 
+        public List<TopicPartition> toForget() {
+            return toForget;
+        }
+
+        public Builder toForget(List<TopicPartition> toForget) {
+            this.toForget = toForget;
+            return this;
+        }
+
         @Override
         public FetchRequest build(short version) {
             if (version < 3) {
                 maxBytes = DEFAULT_RESPONSE_MAX_BYTES;
             }
 
-            return new FetchRequest(version, replicaId, maxWait, minBytes, maxBytes, fetchData, isolationLevel);
+            return new FetchRequest(version, replicaId, maxWait, minBytes, maxBytes, fetchData,
+                isolationLevel, toForget, metadata);
         }
 
         @Override
@@ -257,13 +331,16 @@ public class FetchRequest extends AbstractRequest {
                     append(", maxBytes=").append(maxBytes).
                     append(", fetchData=").append(fetchData).
                     append(", isolationLevel=").append(isolationLevel).
+                    append(", toForget=").append(Utils.join(toForget, ", ")).
+                    append(", metadata=").append(metadata).
                     append(")");
             return bld.toString();
         }
     }
 
     private FetchRequest(short version, int replicaId, int maxWait, int minBytes, int maxBytes,
-                         LinkedHashMap<TopicPartition, PartitionData> fetchData, IsolationLevel isolationLevel) {
+                         Map<TopicPartition, PartitionData> fetchData, IsolationLevel isolationLevel,
+                         List<TopicPartition> toForget, FetchMetadata metadata) {
         super(version);
         this.replicaId = replicaId;
         this.maxWait = maxWait;
@@ -271,6 +348,8 @@ public class FetchRequest extends AbstractRequest {
         this.maxBytes = maxBytes;
         this.fetchData = fetchData;
         this.isolationLevel = isolationLevel;
+        this.toForget = toForget;
+        this.metadata = metadata;
     }
 
     public FetchRequest(Struct struct, short version) {
@@ -282,11 +361,23 @@ public class FetchRequest extends AbstractRequest {
             maxBytes = struct.getInt(MAX_BYTES_KEY_NAME);
         else
             maxBytes = DEFAULT_RESPONSE_MAX_BYTES;
-
         if (struct.hasField(ISOLATION_LEVEL_KEY_NAME))
             isolationLevel = IsolationLevel.forId(struct.getByte(ISOLATION_LEVEL_KEY_NAME));
         else
             isolationLevel = IsolationLevel.READ_UNCOMMITTED;
+        toForget = new ArrayList<>(0);
+        if (struct.hasField(FORGOTTEN_TOPICS_DATA)) {
+            for (Object forgottenTopicObj : struct.getArray(FORGOTTEN_TOPICS_DATA)) {
+                Struct forgottenTopic = (Struct) forgottenTopicObj;
+                String topicName = forgottenTopic.get(TOPIC_NAME);
+                for (Object partObj : forgottenTopic.getArray(PARTITIONS_KEY_NAME)) {
+                    Integer part = (Integer) partObj;
+                    toForget.add(new TopicPartition(topicName, part));
+                }
+            }
+        }
+        metadata = new FetchMetadata(struct.getOrElse(SESSION_ID, INVALID_SESSION_ID),
+            struct.getOrElse(EPOCH, FINAL_EPOCH));
 
         fetchData = new LinkedHashMap<>();
         for (Object topicResponseObj : struct.getArray(TOPICS_KEY_NAME)) {
@@ -307,15 +398,21 @@ public class FetchRequest extends AbstractRequest {
 
     @Override
     public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) {
+        // The error is indicated in two ways: by setting the same error code in all partitions,
+        // and by setting the top-level error code.  The per-partition error code preserves
+        // backwards compatibility with older versions of the protocol, which had no top-level
+        // error code.  The top-level error code is essential for incremental fetch responses,
+        // which may contain no partitions at all.
+        Errors error = Errors.forException(e);
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData> responseData = new LinkedHashMap<>();
-
-        for (Map.Entry<TopicPartition, PartitionData> entry: fetchData.entrySet()) {
-            FetchResponse.PartitionData partitionResponse = new FetchResponse.PartitionData(Errors.forException(e),
-                FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET,
-                null, MemoryRecords.EMPTY);
+        for (Map.Entry<TopicPartition, PartitionData> entry : fetchData.entrySet()) {
+            FetchResponse.PartitionData partitionResponse = new FetchResponse.PartitionData(error,
+                FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
+                FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY);
             responseData.put(entry.getKey(), partitionResponse);
         }
-        return new FetchResponse(responseData, throttleTimeMs);
+        return new FetchResponse(error, responseData, throttleTimeMs, metadata.sessionId());
     }
 
     public int replicaId() {
@@ -338,6 +435,10 @@ public class FetchRequest extends AbstractRequest {
         return fetchData;
     }
 
+    public List<TopicPartition> toForget() {
+        return toForget;
+    }
+
     public boolean isFromFollower() {
         return replicaId >= 0;
     }
@@ -346,6 +447,10 @@ public class FetchRequest extends AbstractRequest {
         return isolationLevel;
     }
 
+    public FetchMetadata metadata() {
+        return metadata;
+    }
+
     public static FetchRequest parse(ByteBuffer buffer, short version) {
         return new FetchRequest(ApiKeys.FETCH.parseRequest(version, buffer), version);
     }
@@ -353,7 +458,8 @@ public class FetchRequest extends AbstractRequest {
     @Override
     protected Struct toStruct() {
         Struct struct = new Struct(ApiKeys.FETCH.requestSchema(version()));
-        List<TopicAndPartitionData<PartitionData>> topicsData = TopicAndPartitionData.batchByTopic(fetchData);
+        List<TopicAndPartitionData<PartitionData>> topicsData =
+            TopicAndPartitionData.batchByTopic(fetchData.entrySet().iterator());
 
         struct.set(REPLICA_ID_KEY_NAME, replicaId);
         struct.set(MAX_WAIT_KEY_NAME, maxWait);
@@ -362,6 +468,8 @@ public class FetchRequest extends AbstractRequest {
             struct.set(MAX_BYTES_KEY_NAME, maxBytes);
         if (struct.hasField(ISOLATION_LEVEL_KEY_NAME))
             struct.set(ISOLATION_LEVEL_KEY_NAME, isolationLevel.id());
+        struct.setIfExists(SESSION_ID, metadata.sessionId());
+        struct.setIfExists(EPOCH, metadata.epoch());
 
         List<Struct> topicArray = new ArrayList<>();
         for (TopicAndPartitionData<PartitionData> topicEntry : topicsData) {
@@ -382,6 +490,25 @@ public class FetchRequest extends AbstractRequest {
             topicArray.add(topicData);
         }
         struct.set(TOPICS_KEY_NAME, topicArray.toArray());
+        if (struct.hasField(FORGOTTEN_TOPICS_DATA)) {
+            Map<String, List<Integer>> topicsToPartitions = new HashMap<>();
+            for (TopicPartition part : toForget) {
+                List<Integer> partitions = topicsToPartitions.get(part.topic());
+                if (partitions == null) {
+                    partitions = new ArrayList<>();
+                    topicsToPartitions.put(part.topic(), partitions);
+                }
+                partitions.add(part.partition());
+            }
+            List<Struct> toForgetStructs = new ArrayList<>();
+            for (Map.Entry<String, List<Integer>> entry : topicsToPartitions.entrySet()) {
+                Struct toForgetStruct = struct.instance(FORGOTTEN_TOPICS_DATA);
+                toForgetStruct.set(TOPIC_NAME, entry.getKey());
+                toForgetStruct.set(PARTITIONS_KEY_NAME, entry.getValue().toArray());
+                toForgetStructs.add(toForgetStruct);
+            }
+            struct.set(FORGOTTEN_TOPICS_DATA, toForgetStructs.toArray());
+        }
         return struct;
     }
 }
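
For illustration, here is a minimal sketch (not part of this commit) of how a client
could assemble an incremental fetch request with the Builder above.  The session id
(123), epoch (1), offsets, and version bounds are hypothetical; in practice
FetchSessionHandler supplies the metadata and the to-forget list.

    import java.util.Collections;
    import java.util.LinkedHashMap;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.requests.FetchMetadata;
    import org.apache.kafka.common.requests.FetchRequest;

    LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData = new LinkedHashMap<>();
    // Fetch foo-1 from offset 10, log start offset 0, at most 1 MiB for this partition.
    fetchData.put(new TopicPartition("foo", 1),
        new FetchRequest.PartitionData(10L, 0L, 1024 * 1024));

    FetchRequest request = new FetchRequest.Builder(
            (short) 0, (short) 7,   // min/max supported request versions (illustrative)
            -1,                     // replicaId: -1 marks a consumer rather than a follower
            500,                    // maxWait in milliseconds
            1,                      // minBytes
            fetchData)
        .metadata(new FetchMetadata(123, 1))  // continue session 123 at epoch 1
        .toForget(Collections.singletonList(  // drop foo-0 from the session
            new TopicPartition("foo", 0)))
        .build((short) 7);  // only v7 serializes session_id, epoch, and forgotten_topics_data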
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java
index 0d09027..98c6be3 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java
@@ -31,6 +31,7 @@ import org.apache.kafka.common.record.Records;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -42,6 +43,7 @@ import static org.apache.kafka.common.protocol.CommonFields.TOPIC_NAME;
 import static org.apache.kafka.common.protocol.types.Type.INT64;
 import static org.apache.kafka.common.protocol.types.Type.RECORDS;
 import static org.apache.kafka.common.protocol.types.Type.STRING;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
 
 /**
  * This wrapper supports all versions of the Fetch API
@@ -148,9 +150,19 @@ public class FetchResponse extends AbstractResponse {
      */
     private static final Schema FETCH_RESPONSE_V6 = FETCH_RESPONSE_V5;
 
+    // FETCH_RESPONSE_V7 added incremental fetch responses and a top-level error code.
+    public static final Field.Int32 SESSION_ID = new Field.Int32("session_id", "The fetch session ID");
+
+    private static final Schema FETCH_RESPONSE_V7 = new Schema(
+        THROTTLE_TIME_MS,
+        ERROR_CODE,
+        SESSION_ID,
+        new Field(RESPONSES_KEY_NAME, new ArrayOf(FETCH_RESPONSE_TOPIC_V5)));
+
     public static Schema[] schemaVersions() {
         return new Schema[] {FETCH_RESPONSE_V0, FETCH_RESPONSE_V1, FETCH_RESPONSE_V2,
-            FETCH_RESPONSE_V3, FETCH_RESPONSE_V4, FETCH_RESPONSE_V5, FETCH_RESPONSE_V6};
+            FETCH_RESPONSE_V3, FETCH_RESPONSE_V4, FETCH_RESPONSE_V5, FETCH_RESPONSE_V6,
+            FETCH_RESPONSE_V7};
     }
 
 
@@ -168,8 +180,10 @@ public class FetchResponse extends AbstractResponse {
      *  UNKNOWN (-1)
      */
 
-    private final LinkedHashMap<TopicPartition, PartitionData> responseData;
     private final int throttleTimeMs;
+    private final Errors error;
+    private final int sessionId;
+    private final LinkedHashMap<TopicPartition, PartitionData> responseData;
 
     public static final class AbortedTransaction {
         public final long producerId;
@@ -268,17 +282,20 @@ public class FetchResponse extends AbstractResponse {
     }
 
     /**
-     * Constructor for all versions.
-     *
      * From version 3 or later, the entries in `responseData` should be in the same order as the entries in
      * `FetchRequest.fetchData`.
      *
-     * @param responseData fetched data grouped by topic-partition
-     * @param throttleTimeMs Time in milliseconds the response was throttled
+     * @param error             The top-level error code.
+     * @param responseData      The fetched data grouped by partition.
+     * @param throttleTimeMs    The time in milliseconds that the response was throttled
+     * @param sessionId         The fetch session id.
      */
-    public FetchResponse(LinkedHashMap<TopicPartition, PartitionData> responseData, int throttleTimeMs) {
+    public FetchResponse(Errors error, LinkedHashMap<TopicPartition, PartitionData> responseData,
+                         int throttleTimeMs, int sessionId) {
+        this.error = error;
         this.responseData = responseData;
         this.throttleTimeMs = throttleTimeMs;
+        this.sessionId = sessionId;
     }
 
     public FetchResponse(Struct struct) {
@@ -316,17 +333,19 @@ public class FetchResponse extends AbstractResponse {
                 }
 
                 PartitionData partitionData = new PartitionData(error, highWatermark, lastStableOffset, logStartOffset,
-                        abortedTransactions, records);
+                    abortedTransactions, records);
                 responseData.put(new TopicPartition(topic, partition), partitionData);
             }
         }
         this.responseData = responseData;
         this.throttleTimeMs = struct.getOrElse(THROTTLE_TIME_MS, DEFAULT_THROTTLE_TIME);
+        this.error = Errors.forCode(struct.getOrElse(ERROR_CODE, (short) 0));
+        this.sessionId = struct.getOrElse(SESSION_ID, INVALID_SESSION_ID);
     }
 
     @Override
     public Struct toStruct(short version) {
-        return toStruct(version, responseData, throttleTimeMs);
+        return toStruct(version, throttleTimeMs, error, responseData.entrySet().iterator(), sessionId);
     }
 
     @Override
@@ -346,6 +365,10 @@ public class FetchResponse extends AbstractResponse {
         return new MultiSend(dest, sends);
     }
 
+    public Errors error() {
+        return error;
+    }
+
     public LinkedHashMap<TopicPartition, PartitionData> responseData() {
         return responseData;
     }
@@ -354,6 +377,10 @@ public class FetchResponse extends AbstractResponse {
         return this.throttleTimeMs;
     }
 
+    public int sessionId() {
+        return sessionId;
+    }
+
     @Override
     public Map<Errors, Integer> errorCounts() {
         Map<Errors, Integer> errorCounts = new HashMap<>();
@@ -369,7 +396,15 @@ public class FetchResponse extends AbstractResponse {
     private static void addResponseData(Struct struct, int throttleTimeMs, String dest, List<Send> sends) {
         Object[] allTopicData = struct.getArray(RESPONSES_KEY_NAME);
 
-        if (struct.hasField(THROTTLE_TIME_MS)) {
+        if (struct.hasField(ERROR_CODE)) {
+            ByteBuffer buffer = ByteBuffer.allocate(14);
+            buffer.putInt(throttleTimeMs);
+            buffer.putShort(struct.get(ERROR_CODE));
+            buffer.putInt(struct.get(SESSION_ID));
+            buffer.putInt(allTopicData.length);
+            buffer.rewind();
+            sends.add(new ByteBufferSend(dest, buffer));
+        } else if (struct.hasField(THROTTLE_TIME_MS)) {
             ByteBuffer buffer = ByteBuffer.allocate(8);
             buffer.putInt(throttleTimeMs);
             buffer.putInt(allTopicData.length);
@@ -416,9 +451,14 @@ public class FetchResponse extends AbstractResponse {
         sends.add(new RecordsSend(dest, records));
     }
 
-    private static Struct toStruct(short version, LinkedHashMap<TopicPartition, PartitionData> responseData, int throttleTimeMs) {
+    private static Struct toStruct(short version, int throttleTimeMs, Errors error,
+            Iterator<Map.Entry<TopicPartition, PartitionData>> partIterator, int sessionId) {
         Struct struct = new Struct(ApiKeys.FETCH.responseSchema(version));
-        List<FetchRequest.TopicAndPartitionData<PartitionData>> topicsData = FetchRequest.TopicAndPartitionData.batchByTopic(responseData);
+        struct.setIfExists(THROTTLE_TIME_MS, throttleTimeMs);
+        struct.setIfExists(ERROR_CODE, error.code());
+        struct.setIfExists(SESSION_ID, sessionId);
+        List<FetchRequest.TopicAndPartitionData<PartitionData>> topicsData =
+            FetchRequest.TopicAndPartitionData.batchByTopic(partIterator);
         List<Struct> topicArray = new ArrayList<>();
         for (FetchRequest.TopicAndPartitionData<PartitionData> topicEntry: topicsData) {
             Struct topicData = struct.instance(RESPONSES_KEY_NAME);
@@ -466,13 +506,20 @@ public class FetchResponse extends AbstractResponse {
             topicArray.add(topicData);
         }
         struct.set(RESPONSES_KEY_NAME, topicArray.toArray());
-        struct.setIfExists(THROTTLE_TIME_MS, throttleTimeMs);
-
         return struct;
     }
 
-    public static int sizeOf(short version, LinkedHashMap<TopicPartition, PartitionData> responseData) {
-        return 4 + toStruct(version, responseData, 0).sizeOf();
+    /**
+     * Convenience method to find the size of a response.
+     *
+     * @param version       The version of the response to use.
+     * @param partIterator  The partition iterator.
+     * @return              The response size in bytes.
+     */
+    public static int sizeOf(short version, Iterator<Map.Entry<TopicPartition, PartitionData>> partIterator) {
+        // Since the throttleTimeMs and metadata field sizes are fixed, we can use arbitrary
+        // values here without affecting the result.
+        return 4 + toStruct(version, 0, Errors.NONE, partIterator, INVALID_SESSION_ID).sizeOf();
     }
 
 }
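
As a concrete reference, the v7 header that addResponseData writes by hand is 14 bytes:
a 4-byte throttle_time_ms, a 2-byte error_code, a 4-byte session_id, and the 4-byte
topic array count.  Below is a sketch (values hypothetical, not from this commit) of
constructing the kind of empty, top-level-error response that only v7 can express:

    import java.util.LinkedHashMap;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.FetchMetadata;
    import org.apache.kafka.common.requests.FetchResponse;

    // An incremental fetch response may contain no partitions at all, so the
    // top-level error code is the only way to report this failure.
    FetchResponse response = new FetchResponse(
        Errors.INVALID_FETCH_SESSION_EPOCH,  // top-level error code (new in v7)
        new LinkedHashMap<TopicPartition, FetchResponse.PartitionData>(),  // empty partition map
        0,                                   // throttleTimeMs
        FetchMetadata.INVALID_SESSION_ID);   // no valid session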
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashSet.java b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashSet.java
new file mode 100644
index 0000000..701684d
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashSet.java
@@ -0,0 +1,354 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.utils;
+
+import java.util.AbstractSet;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+/**
+ * A LinkedHashSet which is more memory-efficient than the standard implementation.
+ *
+ * This set preserves the order of insertion.  The order of iteration will always be
+ * the order of insertion.
+ *
+ * This collection requires previous and next indexes to be embedded into each
+ * element.  Using array indices rather than pointers saves space on large heaps
+ * where pointer compression is not in use.  It also reduces the amount of time
+ * the garbage collector has to spend chasing pointers.
+ *
+ * This class uses linear probing.  Unlike HashMap (but like Hashtable), we don't force
+ * the size to be a power of 2.  This saves memory.
+ *
+ * This class does not have internal synchronization.
+ */
+@SuppressWarnings("unchecked")
+public class ImplicitLinkedHashSet<E extends ImplicitLinkedHashSet.Element> extends AbstractSet<E> {
+    public interface Element {
+        int prev();
+        void setPrev(int e);
+        int next();
+        void setNext(int e);
+    }
+
+    private static final int HEAD_INDEX = -1;
+
+    public static final int INVALID_INDEX = -2;
+
+    private static class HeadElement implements Element {
+        private int prev = HEAD_INDEX;
+        private int next = HEAD_INDEX;
+
+        @Override
+        public int prev() {
+            return prev;
+        }
+
+        @Override
+        public void setPrev(int prev) {
+            this.prev = prev;
+        }
+
+        @Override
+        public int next() {
+            return next;
+        }
+
+        @Override
+        public void setNext(int next) {
+            this.next = next;
+        }
+    }
+
+    private static Element indexToElement(Element head, Element[] elements, int index) {
+        if (index == HEAD_INDEX) {
+            return head;
+        }
+        return elements[index];
+    }
+
+    private static void addToListTail(Element head, Element[] elements, int elementIdx) {
+        int oldTailIdx = head.prev();
+        Element element = indexToElement(head, elements, elementIdx);
+        Element oldTail = indexToElement(head, elements, oldTailIdx);
+        head.setPrev(elementIdx);
+        oldTail.setNext(elementIdx);
+        element.setPrev(oldTailIdx);
+        element.setNext(HEAD_INDEX);
+    }
+
+    private static void removeFromList(Element head, Element[] elements, int elementIdx) {
+        Element element = indexToElement(head, elements, elementIdx);
+        elements[elementIdx] = null;
+        int prevIdx = element.prev();
+        int nextIdx = element.next();
+        Element prev = indexToElement(head, elements, prevIdx);
+        Element next = indexToElement(head, elements, nextIdx);
+        prev.setNext(nextIdx);
+        next.setPrev(prevIdx);
+        element.setNext(INVALID_INDEX);
+        element.setPrev(INVALID_INDEX);
+    }
+
+    private class ImplicitLinkedHashSetIterator implements Iterator<E> {
+        private Element cur = head;
+
+        private Element next = indexToElement(head, elements, head.next());
+
+        @Override
+        public boolean hasNext() {
+            return next != head;
+        }
+
+        @Override
+        public E next() {
+            if (next == head) {
+                throw new NoSuchElementException();
+            }
+            cur = next;
+            next = indexToElement(head, elements, cur.next());
+            return (E) cur;
+        }
+
+        @Override
+        public void remove() {
+            if (cur == head) {
+                throw new IllegalStateException();
+            }
+            ImplicitLinkedHashSet.this.remove(cur);
+            cur = head;
+        }
+    }
+
+    private Element head;
+
+    private Element[] elements;
+
+    private int size;
+
+    @Override
+    public Iterator<E> iterator() {
+        return new ImplicitLinkedHashSetIterator();
+    }
+
+    private static int slot(Element[] curElements, Element e) {
+        return (e.hashCode() & 0x7fffffff) % curElements.length;
+    }
+
+    /**
+     * Find an element matching an example element.
+     *
+     * Using the element's hash code, we can look up the slot where it belongs.
+     * However, it may not have ended up in exactly this slot, due to a collision.
+     * Therefore, we must search forward in the array until we hit a null, before
+     * concluding that the element is not present.
+     *
+     * @param example   The element to match.
+     * @return          The match index, or INVALID_INDEX if no match was found.
+     */
+    private int findIndex(E example) {
+        int slot = slot(elements, example);
+        for (int seen = 0; seen < elements.length; seen++) {
+            Element element = elements[slot];
+            if (element == null) {
+                return INVALID_INDEX;
+            }
+            if (element.equals(example)) {
+                return slot;
+            }
+            slot = (slot + 1) % elements.length;
+        }
+        return INVALID_INDEX;
+    }
+
+    /**
+     * Find the element which equals() the given example element.
+     *
+     * @param example   The example element.
+     * @return          The matching element, or null if none was found.
+     */
+    public E find(E example) {
+        int index = findIndex(example);
+        if (index == INVALID_INDEX) {
+            return null;
+        }
+        return (E) elements[index];
+    }
+
+    /**
+     * Returns the number of elements in the set.
+     */
+    @Override
+    public int size() {
+        return size;
+    }
+
+    @Override
+    public boolean contains(Object o) {
+        E example = null;
+        try {
+            example = (E) o;
+        } catch (ClassCastException e) {
+            return false;
+        }
+        return find(example) != null;
+    }
+
+    @Override
+    public boolean add(E newElement) {
+        if ((size + 1) >= elements.length / 2) {
+            // Avoid using even-sized capacities, to get better key distribution.
+            changeCapacity((2 * elements.length) + 1);
+        }
+        int slot = addInternal(newElement, elements);
+        if (slot >= 0) {
+            addToListTail(head, elements, slot);
+            size++;
+            return true;
+        }
+        return false;
+    }
+
+    public void mustAdd(E newElement) {
+        if (!add(newElement)) {
+            throw new RuntimeException("Unable to add " + newElement);
+        }
+    }
+
+    /**
+     * Adds a new element to the appropriate place in the elements array.
+     *
+     * @param newElement    The new element to add.
+     * @param addElements   The elements array.
+     * @return              The index at which the element was inserted, or INVALID_INDEX
+     *                      if the element could not be inserted because there was already
+     *                      an equivalent element.
+     */
+    private static int addInternal(Element newElement, Element[] addElements) {
+        int slot = slot(addElements, newElement);
+        for (int seen = 0; seen < addElements.length; seen++) {
+            Element element = addElements[slot];
+            if (element == null) {
+                addElements[slot] = newElement;
+                return slot;
+            }
+            if (element.equals(newElement)) {
+                return INVALID_INDEX;
+            }
+            slot = (slot + 1) % addElements.length;
+        }
+        throw new RuntimeException("Not enough hash table slots to add a new element.");
+    }
+
+    private void changeCapacity(int newCapacity) {
+        Element[] newElements = new Element[newCapacity];
+        HeadElement newHead = new HeadElement();
+        int oldSize = size;
+        for (Iterator<E> iter = iterator(); iter.hasNext(); ) {
+            Element element = iter.next();
+            iter.remove();
+            int newSlot = addInternal(element, newElements);
+            addToListTail(newHead, newElements, newSlot);
+        }
+        this.elements = newElements;
+        this.head = newHead;
+        this.size = oldSize;
+    }
+
+    @Override
+    public boolean remove(Object o) {
+        E example = null;
+        try {
+            example = (E) o;
+        } catch (ClassCastException e) {
+            return false;
+        }
+        int slot = findIndex(example);
+        if (slot == INVALID_INDEX) {
+            return false;
+        }
+        size--;
+        removeFromList(head, elements, slot);
+        slot = (slot + 1) % elements.length;
+
+        // Find the next empty slot
+        int endSlot = slot;
+        for (int seen = 0; seen < elements.length; seen++) {
+            Element element = elements[endSlot];
+            if (element == null) {
+                break;
+            }
+            endSlot = (endSlot + 1) % elements.length;
+        }
+
+        // We must preserve the denseness invariant.  The denseness invariant says that
+        // any element is either in the slot indicated by its hash code, or a slot which
+        // is not separated from that slot by any nulls.
+        // Reseat all elements in between the deleted element and the next empty slot.
+        while (slot != endSlot) {
+            reseat(slot);
+            slot = (slot + 1) % elements.length;
+        }
+        return true;
+    }
+
+    private void reseat(int prevSlot) {
+        Element element = elements[prevSlot];
+        int newSlot = slot(elements, element);
+        for (int seen = 0; seen < elements.length; seen++) {
+            Element e = elements[newSlot];
+            if ((e == null) || (e == element)) {
+                break;
+            }
+            newSlot = (newSlot + 1) % elements.length;
+        }
+        if (newSlot == prevSlot) {
+            return;
+        }
+        Element prev = indexToElement(head, elements, element.prev());
+        prev.setNext(newSlot);
+        Element next = indexToElement(head, elements, element.next());
+        next.setPrev(newSlot);
+        elements[prevSlot] = null;
+        elements[newSlot] = element;
+    }
+
+    @Override
+    public void clear() {
+        reset(elements.length);
+    }
+
+    public ImplicitLinkedHashSet() {
+        this(5);
+    }
+
+    public ImplicitLinkedHashSet(int initialCapacity) {
+        reset(initialCapacity);
+    }
+
+    private void reset(int capacity) {
+        this.head = new HeadElement();
+        // Avoid using even-sized capacities, to get better key distribution.
+        this.elements = new Element[(2 * capacity) + 1];
+        this.size = 0;
+    }
+
+    int numSlots() {
+        return elements.length;
+    }
+}
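
As a usage sketch (not from this commit), an element type must embed its own prev/next
indices and define equals()/hashCode() for the by-example lookups.  The TrackedPartition
class below is hypothetical:

    import org.apache.kafka.common.utils.ImplicitLinkedHashSet;

    // Hypothetical element type: identity is (topic, partition); the prev/next
    // fields are the embedded list links that the set manages for us.
    class TrackedPartition implements ImplicitLinkedHashSet.Element {
        private final String topic;
        private final int partition;
        private int prev = ImplicitLinkedHashSet.INVALID_INDEX;
        private int next = ImplicitLinkedHashSet.INVALID_INDEX;

        TrackedPartition(String topic, int partition) {
            this.topic = topic;
            this.partition = partition;
        }

        @Override public int prev() { return prev; }
        @Override public void setPrev(int prev) { this.prev = prev; }
        @Override public int next() { return next; }
        @Override public void setNext(int next) { this.next = next; }

        // equals/hashCode define identity for find(), contains(), and remove().
        @Override public boolean equals(Object o) {
            if (!(o instanceof TrackedPartition)) return false;
            TrackedPartition that = (TrackedPartition) o;
            return topic.equals(that.topic) && partition == that.partition;
        }

        @Override public int hashCode() {
            return 31 * topic.hashCode() + partition;
        }
    }

    ImplicitLinkedHashSet<TrackedPartition> set = new ImplicitLinkedHashSet<>();
    set.mustAdd(new TrackedPartition("foo", 0));
    set.mustAdd(new TrackedPartition("foo", 1));
    // find() returns the stored element matching a throwaway "example" key.
    TrackedPartition stored = set.find(new TrackedPartition("foo", 0));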
diff --git a/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java
new file mode 100644
index 0000000..3095717
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+
+import static org.apache.kafka.common.requests.FetchMetadata.INITIAL_EPOCH;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * A unit test for FetchSessionHandler.
+ */
+public class FetchSessionHandlerTest {
+    @Rule
+    public final Timeout globalTimeout = Timeout.millis(120000);
+
+    private static final LogContext LOG_CONTEXT = new LogContext("[FetchSessionHandler]=");
+
+    private static final Logger log = LOG_CONTEXT.logger(FetchSessionHandler.class);
+
+    /**
+     * Create a set of TopicPartitions.  We use a TreeSet, in order to get a deterministic
+     * ordering for test purposes.
+     */
+    private static Set<TopicPartition> toSet(TopicPartition... arr) {
+        TreeSet<TopicPartition> set = new TreeSet<>(new Comparator<TopicPartition>() {
+            @Override
+            public int compare(TopicPartition o1, TopicPartition o2) {
+                return o1.toString().compareTo(o2.toString());
+            }
+        });
+        set.addAll(Arrays.asList(arr));
+        return set;
+    }
+
+    @Test
+    public void testFindMissing() throws Exception {
+        TopicPartition foo0 = new TopicPartition("foo", 0);
+        TopicPartition foo1 = new TopicPartition("foo", 1);
+        TopicPartition bar0 = new TopicPartition("bar", 0);
+        TopicPartition bar1 = new TopicPartition("bar", 1);
+        TopicPartition baz0 = new TopicPartition("baz", 0);
+        TopicPartition baz1 = new TopicPartition("baz", 1);
+        assertEquals(toSet(), FetchSessionHandler.findMissing(toSet(foo0), toSet(foo0)));
+        assertEquals(toSet(foo0), FetchSessionHandler.findMissing(toSet(foo0), toSet(foo1)));
+        assertEquals(toSet(foo0, foo1),
+            FetchSessionHandler.findMissing(toSet(foo0, foo1), toSet(baz0)));
+        assertEquals(toSet(bar1, foo0, foo1),
+            FetchSessionHandler.findMissing(toSet(foo0, foo1, bar0, bar1),
+                toSet(bar0, baz0, baz1)));
+        assertEquals(toSet(),
+            FetchSessionHandler.findMissing(toSet(foo0, foo1, bar0, bar1, baz1),
+                toSet(foo0, foo1, bar0, bar1, baz0, baz1)));
+    }
+
+    private static final class ReqEntry {
+        final TopicPartition part;
+        final FetchRequest.PartitionData data;
+
+        ReqEntry(String topic, int partition, long fetchOffset, long logStartOffset, int maxBytes) {
+            this.part = new TopicPartition(topic, partition);
+            this.data = new FetchRequest.PartitionData(fetchOffset, logStartOffset, maxBytes);
+        }
+    }
+
+    private static LinkedHashMap<TopicPartition, FetchRequest.PartitionData> reqMap(ReqEntry... entries) {
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> map = new LinkedHashMap<>();
+        for (ReqEntry entry : entries) {
+            map.put(entry.part, entry.data);
+        }
+        return map;
+    }
+
+    private static void assertMapEquals(Map<TopicPartition, FetchRequest.PartitionData> expected,
+                                        Map<TopicPartition, FetchRequest.PartitionData> actual) {
+        Iterator<Map.Entry<TopicPartition, FetchRequest.PartitionData>> expectedIter =
+            expected.entrySet().iterator();
+        Iterator<Map.Entry<TopicPartition, FetchRequest.PartitionData>> actualIter =
+            actual.entrySet().iterator();
+        int i = 1;
+        while (expectedIter.hasNext()) {
+            Map.Entry<TopicPartition, FetchRequest.PartitionData> expectedEntry = expectedIter.next();
+            if (!actualIter.hasNext()) {
+                fail("Element " + i + " not found.");
+            }
+            Map.Entry<TopicPartition, FetchRequest.PartitionData> actualEntry = actualIter.next();
+            assertEquals("Element " + i + " had a different TopicPartition than expected.",
+                expectedEntry.getKey(), actualEntry.getKey());
+            assertEquals("Element " + i + " had different PartitionData than expected.",
+                expectedEntry.getValue(), actualEntry.getValue());
+            i++;
+        }
+        if (actualIter.hasNext()) {
+            fail("Unexpected element " + i + " found.");
+        }
+    }
+
+    private static void assertMapsEqual(Map<TopicPartition, FetchRequest.PartitionData> expected,
+                                        Map<TopicPartition, FetchRequest.PartitionData>... actuals) {
+        for (Map<TopicPartition, FetchRequest.PartitionData> actual : actuals) {
+            assertMapEquals(expected, actual);
+        }
+    }
+
+    private static void assertListEquals(List<TopicPartition> expected, List<TopicPartition> actual) {
+        for (TopicPartition expectedPart : expected) {
+            if (!actual.contains(expectedPart)) {
+                fail("Failed to find expected partition " + expectedPart);
+            }
+        }
+        for (TopicPartition actualPart : actual) {
+            if (!expected.contains(actualPart)) {
+                fail("Found unexpected partition " + actualPart);
+            }
+        }
+    }
+
+    private static final class RespEntry {
+        final TopicPartition part;
+        final FetchResponse.PartitionData data;
+
+        RespEntry(String topic, int partition, long highWatermark, long lastStableOffset) {
+            this.part = new TopicPartition(topic, partition);
+            this.data = new FetchResponse.PartitionData(
+                Errors.NONE,
+                highWatermark,
+                lastStableOffset,
+                0,
+                null,
+                null);
+        }
+    }
+
+    private static LinkedHashMap<TopicPartition, FetchResponse.PartitionData> respMap(RespEntry... entries) {
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> map = new LinkedHashMap<>();
+        for (RespEntry entry : entries) {
+            map.put(entry.part, entry.data);
+        }
+        return map;
+    }
+
+    /**
+     * Test the handling of SESSIONLESS responses.
+     * Pre-KIP-227 brokers always supply this kind of response.
+     */
+    @Test
+    public void testSessionless() throws Exception {
+        FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
+        FetchSessionHandler.Builder builder = handler.newBuilder();
+        builder.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        builder.add(new TopicPartition("foo", 1),
+            new FetchRequest.PartitionData(10, 110, 210));
+        FetchSessionHandler.FetchRequestData data = builder.build();
+        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
+                               new ReqEntry("foo", 1, 10, 110, 210)),
+            data.toSend(), data.sessionPartitions());
+        assertEquals(INVALID_SESSION_ID, data.metadata().sessionId());
+        assertEquals(INITIAL_EPOCH, data.metadata().epoch());
+
+        FetchResponse resp = new FetchResponse(Errors.NONE,
+            respMap(new RespEntry("foo", 0, 0, 0),
+                    new RespEntry("foo", 1, 0, 0)),
+            0, INVALID_SESSION_ID);
+        handler.handleResponse(resp);
+
+        FetchSessionHandler.Builder builder2 = handler.newBuilder();
+        builder2.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        FetchSessionHandler.FetchRequestData data2 = builder2.build();
+        assertEquals(INVALID_SESSION_ID, data2.metadata().sessionId());
+        assertEquals(INITIAL_EPOCH, data2.metadata().epoch());
+        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+            data2.toSend(), data2.sessionPartitions());
+    }
+
+    /**
+     * Test handling an incremental fetch session.
+     */
+    @Test
+    public void testIncrementals() throws Exception {
+        FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
+        FetchSessionHandler.Builder builder = handler.newBuilder();
+        builder.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        builder.add(new TopicPartition("foo", 1),
+            new FetchRequest.PartitionData(10, 110, 210));
+        FetchSessionHandler.FetchRequestData data = builder.build();
+        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
+            new ReqEntry("foo", 1, 10, 110, 210)),
+            data.toSend(), data.sessionPartitions());
+        assertEquals(INVALID_SESSION_ID, data.metadata().sessionId());
+        assertEquals(INITIAL_EPOCH, data.metadata().epoch());
+
+        FetchResponse resp = new FetchResponse(Errors.NONE,
+            respMap(new RespEntry("foo", 0, 10, 20),
+                    new RespEntry("foo", 1, 10, 20)),
+            0, 123);
+        handler.handleResponse(resp);
+
+        // Test an incremental fetch request which adds one partition and modifies another.
+        FetchSessionHandler.Builder builder2 = handler.newBuilder();
+        builder2.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        builder2.add(new TopicPartition("foo", 1),
+            new FetchRequest.PartitionData(10, 120, 210));
+        builder2.add(new TopicPartition("bar", 0),
+            new FetchRequest.PartitionData(20, 200, 200));
+        FetchSessionHandler.FetchRequestData data2 = builder2.build();
+        assertFalse(data2.metadata().isFull());
+        assertMapEquals(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
+                new ReqEntry("foo", 1, 10, 120, 210),
+                new ReqEntry("bar", 0, 20, 200, 200)),
+            data2.sessionPartitions());
+        assertMapEquals(reqMap(new ReqEntry("bar", 0, 20, 200, 200),
+                new ReqEntry("foo", 1, 10, 120, 210)),
+            data2.toSend());
+
+        FetchResponse resp2 = new FetchResponse(Errors.NONE,
+            respMap(new RespEntry("foo", 1, 20, 20)),
+            0, 123);
+        handler.handleResponse(resp2);
+
+        // Skip building a new request.  Test that handling an invalid fetch session epoch response results
+        // in a request which closes the session.
+        FetchResponse resp3 = new FetchResponse(Errors.INVALID_FETCH_SESSION_EPOCH, respMap(),
+            0, INVALID_SESSION_ID);
+        handler.handleResponse(resp3);
+
+        FetchSessionHandler.Builder builder4 = handler.newBuilder();
+        builder4.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        builder4.add(new TopicPartition("foo", 1),
+            new FetchRequest.PartitionData(10, 120, 210));
+        builder4.add(new TopicPartition("bar", 0),
+            new FetchRequest.PartitionData(20, 200, 200));
+        FetchSessionHandler.FetchRequestData data4 = builder4.build();
+        assertTrue(data4.metadata().isFull());
+        assertEquals(data2.metadata().sessionId(), data4.metadata().sessionId());
+        assertEquals(INITIAL_EPOCH, data4.metadata().epoch());
+        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
+            new ReqEntry("foo", 1, 10, 120, 210),
+            new ReqEntry("bar", 0, 20, 200, 200)),
+            data4.sessionPartitions(), data4.toSend());
+    }
+
+    /**
+     * Test that calling FetchSessionHandler#Builder#build twice fails.
+     */
+    @Test
+    public void testDoubleBuild() throws Exception {
+        FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
+        FetchSessionHandler.Builder builder = handler.newBuilder();
+        builder.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        builder.build();
+        try {
+            builder.build();
+            fail("Expected calling build twice to fail.");
+        } catch (Throwable t) {
+            // expected
+        }
+    }
+
+    @Test
+    public void testIncrementalPartitionRemoval() throws Exception {
+        FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
+        FetchSessionHandler.Builder builder = handler.newBuilder();
+        builder.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        builder.add(new TopicPartition("foo", 1),
+            new FetchRequest.PartitionData(10, 110, 210));
+        builder.add(new TopicPartition("bar", 0),
+            new FetchRequest.PartitionData(20, 120, 220));
+        FetchSessionHandler.FetchRequestData data = builder.build();
+        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
+            new ReqEntry("foo", 1, 10, 110, 210),
+            new ReqEntry("bar", 0, 20, 120, 220)),
+            data.toSend(), data.sessionPartitions());
+        assertTrue(data.metadata().isFull());
+
+        FetchResponse resp = new FetchResponse(Errors.NONE,
+            respMap(new RespEntry("foo", 0, 10, 20),
+                    new RespEntry("foo", 1, 10, 20),
+                    new RespEntry("bar", 0, 10, 20)),
+            0, 123);
+        handler.handleResponse(resp);
+
+        // Test an incremental fetch request which removes two partitions.
+        FetchSessionHandler.Builder builder2 = handler.newBuilder();
+        builder2.add(new TopicPartition("foo", 1),
+            new FetchRequest.PartitionData(10, 110, 210));
+        FetchSessionHandler.FetchRequestData data2 = builder2.build();
+        assertFalse(data2.metadata().isFull());
+        assertEquals(123, data2.metadata().sessionId());
+        assertEquals(1, data2.metadata().epoch());
+        assertMapEquals(reqMap(new ReqEntry("foo", 1, 10, 110, 210)),
+            data2.sessionPartitions());
+        assertMapEquals(reqMap(), data2.toSend());
+        ArrayList<TopicPartition> expectedToForget2 = new ArrayList<>();
+        expectedToForget2.add(new TopicPartition("foo", 0));
+        expectedToForget2.add(new TopicPartition("bar", 0));
+        assertListEquals(expectedToForget2, data2.toForget());
+
+        // A FETCH_SESSION_ID_NOT_FOUND response triggers us to close the session.
+        // The next request is a session-establishing full request.
+        FetchResponse resp2 = new FetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND,
+            respMap(), 0, INVALID_SESSION_ID);
+        handler.handleResponse(resp2);
+        FetchSessionHandler.Builder builder3 = handler.newBuilder();
+        builder3.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(0, 100, 200));
+        FetchSessionHandler.FetchRequestData data3 = builder3.build();
+        assertTrue(data3.metadata().isFull());
+        assertEquals(INVALID_SESSION_ID, data3.metadata().sessionId());
+        assertEquals(INITIAL_EPOCH, data3.metadata().epoch());
+        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+            data3.sessionPartitions(), data3.toSend());
+    }
+}
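
The tests above walk through the client-side session lifecycle; condensed into a sketch
(logContext, node, and response are assumed to be in scope, and the network send is
elided):

    FetchSessionHandler handler = new FetchSessionHandler(logContext, node.id());

    // 1. The first build() produces a FULL request: epoch 0, no session id.
    FetchSessionHandler.Builder builder = handler.newBuilder();
    builder.add(new TopicPartition("foo", 0),
        new FetchRequest.PartitionData(0L, 0L, 1024 * 1024));
    FetchSessionHandler.FetchRequestData data = builder.build();

    // 2. Send data.toSend() along with data.metadata(); data.toForget() lists
    //    the partitions removed since the previous request.

    // 3. On a successful response the handler records the broker's session id and
    //    bumps the epoch, so the next build() is incremental.  On
    //    INVALID_FETCH_SESSION_EPOCH or FETCH_SESSION_ID_NOT_FOUND it falls back
    //    to a full request.
    handler.handleResponse(response);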
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index a827168..d47124f 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -99,6 +99,7 @@ import java.util.regex.Pattern;
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
@@ -1578,7 +1579,7 @@ public class KafkaConsumerTest {
             tpResponses.put(partition, new FetchResponse.PartitionData(Errors.NONE, 0, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L,
                     null, records));
         }
-        return new FetchResponse(tpResponses, 0);
+        return new FetchResponse(Errors.NONE, tpResponses, 0, INVALID_SESSION_ID);
     }
 
     private FetchResponse fetchResponse(TopicPartition partition, long fetchOffset, int count) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index a3ea793..a0205e7 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -103,6 +103,7 @@ import java.util.Map;
 import java.util.Set;
 
 import static java.util.Collections.singleton;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -139,6 +140,7 @@ public class FetcherTest {
 
     private MemoryRecords records;
     private MemoryRecords nextRecords;
+    private MemoryRecords emptyRecords;
     private Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, metrics);
     private Metrics fetcherMetrics = new Metrics(time);
     private Fetcher<byte[], byte[]> fetcherNoAutoReset = createFetcher(subscriptionsNoAutoReset, fetcherMetrics);
@@ -158,6 +160,9 @@ public class FetcherTest {
         builder.append(0L, "key".getBytes(), "value-4".getBytes());
         builder.append(0L, "key".getBytes(), "value-5".getBytes());
         nextRecords = builder.build();
+
+        builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
+        emptyRecords = builder.build();
     }
 
     @After
@@ -177,7 +182,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -219,7 +224,7 @@ public class FetcherTest {
 
         buffer.flip();
 
-        client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -242,7 +247,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -283,7 +288,7 @@ public class FetcherTest {
         subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
-        client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(0);
@@ -345,7 +350,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(0);
 
         // the first fetchedRecords() should return the first valid message
@@ -383,7 +388,7 @@ public class FetcherTest {
         // Should not throw exception after the seek.
         fetcher.fetchedRecords();
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0));
         consumerClient.poll(0);
 
         List<ConsumerRecord<byte[], byte[]>> records = fetcher.fetchedRecords().get(tp0);
@@ -416,7 +421,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(0);
 
         // the fetchedRecords() should always throw exception due to the bad batch.
@@ -447,7 +452,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         try {
             fetcher.fetchedRecords();
@@ -480,7 +485,7 @@ public class FetcherTest {
         subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
-        client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, memoryRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, memoryRecords, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(0);
@@ -510,8 +515,8 @@ public class FetcherTest {
         subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
-        client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
-        client.prepareResponse(matchesOffset(tp0, 4), fetchResponse(tp0, this.nextRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tp0, 4), fullFetchResponse(tp0, this.nextRecords, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(0);
@@ -551,7 +556,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 1);
 
         // Returns 3 records while `max.poll.records` is configured to 2
-        client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(0);
@@ -562,7 +567,7 @@ public class FetcherTest {
         assertEquals(2, records.get(1).offset());
 
         subscriptions.assignFromUser(singleton(tp1));
-        client.prepareResponse(matchesOffset(tp1, 4), fetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tp1, 4), fullFetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0));
         subscriptions.seek(tp1, 4);
 
         assertEquals(1, fetcher.sendFetches());
@@ -594,7 +599,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         consumerRecords = fetcher.fetchedRecords().get(tp0);
         assertEquals(3, consumerRecords.size());
@@ -654,7 +659,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
         MemoryRecords partialRecord = MemoryRecords.readableRecords(
             ByteBuffer.wrap(new byte[]{0, 0, 0, 0, 0, 0, 0, 0}));
-        client.prepareResponse(fetchResponse(tp0, partialRecord, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, partialRecord, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
     }
@@ -666,7 +671,7 @@ public class FetcherTest {
 
         // resize the limit of the buffer to pretend it is only fetch-size large
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0));
         consumerClient.poll(0);
         try {
             fetcher.fetchedRecords();
@@ -686,7 +691,7 @@ public class FetcherTest {
 
         // Now the rebalance happens and fetch positions are cleared
         subscriptions.assignFromSubscribed(singleton(tp0));
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
 
         // The active fetch should be ignored since its position is no longer valid
@@ -701,7 +706,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         subscriptions.pause(tp0);
 
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         assertNull(fetcher.fetchedRecords().get(tp0));
     }
@@ -722,7 +727,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0));
         consumerClient.poll(0);
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
@@ -734,7 +739,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0));
         consumerClient.poll(0);
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
@@ -746,7 +751,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(0);
         assertEquals(0, fetcher.fetchedRecords().size());
         assertTrue(subscriptions.isOffsetResetNeeded(tp0));
@@ -761,7 +766,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         subscriptions.seek(tp0, 1);
         consumerClient.poll(0);
         assertEquals(0, fetcher.fetchedRecords().size());
@@ -775,7 +780,7 @@ public class FetcherTest {
         subscriptionsNoAutoReset.seek(tp0, 0);
 
         assertTrue(fetcherNoAutoReset.sendFetches() > 0);
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(0);
         assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0));
         subscriptionsNoAutoReset.seek(tp0, 2);
@@ -788,7 +793,7 @@ public class FetcherTest {
         subscriptionsNoAutoReset.seek(tp0, 0);
 
         fetcherNoAutoReset.sendFetches();
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(0);
 
         assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0));
@@ -818,7 +823,8 @@ public class FetcherTest {
             FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, records));
         partitions.put(tp0, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY));
-        client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0));
+        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions),
+            0, INVALID_SESSION_ID));
         consumerClient.poll(0);
 
         List<ConsumerRecord<byte[], byte[]>> fetchedRecords = new ArrayList<>();
@@ -856,7 +862,7 @@ public class FetcherTest {
         Map<TopicPartition, FetchResponse.PartitionData> partitions = new HashMap<>();
         partitions.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100,
             FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, records));
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
 
         assertEquals(2, fetcher.fetchedRecords().get(tp0).size());
@@ -867,7 +873,7 @@ public class FetcherTest {
         partitions = new HashMap<>();
         partitions.put(tp1, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100,
             FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY));
-        client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0));
+        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), 0, INVALID_SESSION_ID));
         consumerClient.poll(0);
         assertEquals(1, fetcher.fetchedRecords().get(tp0).size());
 
@@ -882,7 +888,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0), true);
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0), true);
         consumerClient.poll(0);
         assertEquals(0, fetcher.fetchedRecords().size());
 
@@ -1148,7 +1154,7 @@ public class FetcherTest {
             ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true, null);
             client.send(request, time.milliseconds());
             client.poll(1, time.milliseconds());
-            FetchResponse response = fetchResponse(tp0, nextRecords, Errors.NONE, i, throttleTimeMs);
+            FetchResponse response = fullFetchResponse(tp0, nextRecords, Errors.NONE, i, throttleTimeMs);
             buffer = response.serialize(ApiKeys.FETCH.latestVersion(), new ResponseHeader(request.correlationId()));
             selector.completeReceive(new NetworkReceive(node.idString(), buffer));
             client.poll(1, time.milliseconds());
@@ -1325,7 +1331,8 @@ public class FetcherTest {
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, MemoryRecords.EMPTY));
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0));
+        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions),
+            0, INVALID_SESSION_ID));
         consumerClient.poll(0);
         fetcher.fetchedRecords();
 
@@ -1364,7 +1371,8 @@ public class FetcherTest {
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null,
                 MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("val".getBytes()))));
 
-        client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0));
+        client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions),
+            0, INVALID_SESSION_ID));
         consumerClient.poll(0);
         fetcher.fetchedRecords();
 
@@ -1390,7 +1398,7 @@ public class FetcherTest {
         subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetcher.fetchedRecords();
@@ -1417,7 +1425,7 @@ public class FetcherTest {
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchRecords(
             TopicPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, int throttleTime) {
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp, records, error, hw, lastStableOffset, throttleTime));
+        client.prepareResponse(fullFetchResponse(tp, records, error, hw, lastStableOffset, throttleTime));
         consumerClient.poll(0);
         return fetcher.fetchedRecords();
     }
@@ -1495,7 +1503,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -1533,7 +1541,7 @@ public class FetcherTest {
                 assertEquals(IsolationLevel.READ_COMMITTED, request.isolationLevel());
                 return true;
             }
-        }, fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        }, fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
 
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
@@ -1604,7 +1612,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -1651,7 +1659,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -1695,7 +1703,7 @@ public class FetcherTest {
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
         abortedTransactions.add(new FetchResponse.AbortedTransaction(producerId, 0L));
 
-        client.prepareResponse(fetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
+        client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
                 abortedTransactions, Errors.NONE, 100L, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
@@ -1733,7 +1741,7 @@ public class FetcherTest {
         subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, compactedRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, compactedRecords, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -1768,7 +1776,7 @@ public class FetcherTest {
         subscriptions.assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponse(tp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -1829,7 +1837,7 @@ public class FetcherTest {
         abortedTransactions.add(new FetchResponse.AbortedTransaction(pid2, 6L));
         abortedTransactions.add(new FetchResponse.AbortedTransaction(pid1, 0L));
 
-        client.prepareResponse(fetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
+        client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
                 abortedTransactions, Errors.NONE, 100L, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
@@ -1867,7 +1875,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -1900,7 +1908,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0));
         consumerClient.poll(0);
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -1911,6 +1919,75 @@ public class FetcherTest {
         assertEquals(currentOffset, (long) subscriptions.position(tp0));
     }
 
+    @Test
+    public void testConsumingViaIncrementalFetchRequests() {
+        Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(time), 2);
+
+        List<ConsumerRecord<byte[], byte[]>> records;
+        subscriptions.assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1)));
+        subscriptions.seek(tp0, 0);
+        subscriptions.seek(tp1, 1);
+
+        // Fetch some records and establish an incremental fetch session.
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> partitions1 = new LinkedHashMap<>();
+        partitions1.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 2L,
+            2, 0L, null, this.records));
+        partitions1.put(tp1, new FetchResponse.PartitionData(Errors.NONE, 100L,
+            FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, emptyRecords));
+        FetchResponse resp1 = new FetchResponse(Errors.NONE, partitions1, 0, 123);
+        client.prepareResponse(resp1);
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+        consumerClient.poll(0);
+        assertTrue(fetcher.hasCompletedFetches());
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetcher.fetchedRecords();
+        assertFalse(fetchedRecords.containsKey(tp1));
+        records = fetchedRecords.get(tp0);
+        assertEquals(2, records.size());
+        assertEquals(3L, subscriptions.position(tp0).longValue());
+        assertEquals(1L, subscriptions.position(tp1).longValue());
+        assertEquals(1, records.get(0).offset());
+        assertEquals(2, records.get(1).offset());
+
+        // There is still a buffered record.
+        assertEquals(0, fetcher.sendFetches());
+        fetchedRecords = fetcher.fetchedRecords();
+        assertFalse(fetchedRecords.containsKey(tp1));
+        records = fetchedRecords.get(tp0);
+        assertEquals(1, records.size());
+        assertEquals(3, records.get(0).offset());
+        assertEquals(4L, subscriptions.position(tp0).longValue());
+
+        // The second response contains no new records.
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> partitions2 = new LinkedHashMap<>();
+        FetchResponse resp2 = new FetchResponse(Errors.NONE, partitions2, 0, 123);
+        client.prepareResponse(resp2);
+        assertEquals(1, fetcher.sendFetches());
+        consumerClient.poll(0);
+        fetchedRecords = fetcher.fetchedRecords();
+        assertTrue(fetchedRecords.isEmpty());
+        assertEquals(4L, subscriptions.position(tp0).longValue());
+        assertEquals(1L, subscriptions.position(tp1).longValue());
+
+        // The third response contains some new records for tp0.
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> partitions3 = new LinkedHashMap<>();
+        partitions3.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100L,
+            4, 0L, null, this.nextRecords));
+        FetchResponse resp3 = new FetchResponse(Errors.NONE, partitions3, 0, 123);
+        client.prepareResponse(resp3);
+        assertEquals(1, fetcher.sendFetches());
+        consumerClient.poll(0);
+        fetchedRecords = fetcher.fetchedRecords();
+        assertFalse(fetchedRecords.containsKey(tp1));
+        records = fetchedRecords.get(tp0);
+        assertEquals(2, records.size());
+        assertEquals(6L, subscriptions.position(tp0).longValue());
+        assertEquals(1L, subscriptions.position(tp1).longValue());
+        assertEquals(4, records.get(0).offset());
+        assertEquals(5, records.get(1).offset());
+    }
+
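
The test above exercises the KIP-227 epoch rules that FetchMetadata encodes: epoch 0
asks the broker for a new session, and each successful incremental fetch advances the
epoch by one. A minimal sketch of that bookkeeping, assuming the wrap-around skips 0
so that a long-lived session is never mistaken for a request to open a new one:

    object EpochSketch {
      val INITIAL_EPOCH = 0
      val FINAL_EPOCH = -1

      def nextEpoch(prevEpoch: Int): Int =
        if (prevEpoch < 0) FINAL_EPOCH          // closed sessions stay closed
        else if (prevEpoch == Int.MaxValue) 1   // skip 0, which means "new session"
        else prevEpoch + 1
    }
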
     private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, int baseSequence, SimpleRecord... records) {
         MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
                 TimestampType.CREATE_TIME, baseOffset, time.milliseconds(), pid, (short) 0, baseSequence, true,
@@ -2033,7 +2110,7 @@ public class FetcherTest {
         return new ListOffsetResponse(allPartitionData);
     }
 
-    private FetchResponse fetchResponseWithAbortedTransactions(MemoryRecords records,
+    private FetchResponse fullFetchResponseWithAbortedTransactions(MemoryRecords records,
                                                                List<FetchResponse.AbortedTransaction> abortedTransactions,
                                                                Errors error,
                                                                long lastStableOffset,
@@ -2041,18 +2118,18 @@ public class FetcherTest {
                                                                int throttleTime) {
         Map<TopicPartition, FetchResponse.PartitionData> partitions = Collections.singletonMap(tp0,
                 new FetchResponse.PartitionData(error, hw, lastStableOffset, 0L, abortedTransactions, records));
-        return new FetchResponse(new LinkedHashMap<>(partitions), throttleTime);
+        return new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), throttleTime, INVALID_SESSION_ID);
     }
 
-    private FetchResponse fetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) {
-        return fetchResponse(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime);
+    private FetchResponse fullFetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) {
+        return fullFetchResponse(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime);
     }
 
-    private FetchResponse fetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw,
+    private FetchResponse fullFetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw,
                                         long lastStableOffset, int throttleTime) {
         Map<TopicPartition, FetchResponse.PartitionData> partitions = Collections.singletonMap(tp,
                 new FetchResponse.PartitionData(error, hw, lastStableOffset, 0L, null, records));
-        return new FetchResponse(new LinkedHashMap<>(partitions), throttleTime);
+        return new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), throttleTime, INVALID_SESSION_ID);
     }
 
     private MetadataResponse newMetadataResponse(String topic, Errors error) {
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
index 0f7429e..bdbd106 100644
--- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
@@ -75,6 +75,7 @@ import java.util.Set;
 
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
+import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
 import static org.apache.kafka.test.TestUtils.toBuffer;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -97,6 +98,13 @@ public class RequestResponseTest {
         checkErrorResponse(createControlledShutdownRequest(0), new UnknownServerException());
         checkRequest(createFetchRequest(4));
         checkResponse(createFetchResponse(), 4);
+        List<TopicPartition> toForgetTopics = new ArrayList<>();
+        toForgetTopics.add(new TopicPartition("foo", 0));
+        toForgetTopics.add(new TopicPartition("foo", 2));
+        toForgetTopics.add(new TopicPartition("bar", 0));
+        checkRequest(createFetchRequest(7, new FetchMetadata(123, 456), toForgetTopics));
+        checkResponse(createFetchResponse(123), 7);
+        checkResponse(createFetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, 123), 7);
         checkErrorResponse(createFetchRequest(4), new UnknownServerException());
         checkRequest(createHeartBeatRequest());
         checkErrorResponse(createHeartBeatRequest(), new UnknownServerException());
@@ -459,8 +467,8 @@ public class RequestResponseTest {
         responseData.put(new TopicPartition("test", 0), new FetchResponse.PartitionData(Errors.NONE, 1000000,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records));
 
-        FetchResponse v0Response = new FetchResponse(responseData, 0);
-        FetchResponse v1Response = new FetchResponse(responseData, 10);
+        FetchResponse v0Response = new FetchResponse(Errors.NONE, responseData, 0, INVALID_SESSION_ID);
+        FetchResponse v1Response = new FetchResponse(Errors.NONE, responseData, 10, INVALID_SESSION_ID);
         assertEquals("Throttle time must be zero", 0, v0Response.throttleTimeMs());
         assertEquals("Throttle time must be 10", 10, v1Response.throttleTimeMs());
         assertEquals("Should use schema version 0", ApiKeys.FETCH.responseSchema((short) 0),
@@ -488,15 +496,22 @@ public class RequestResponseTest {
         responseData.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData(Errors.NONE, 70000,
                 6, FetchResponse.INVALID_LOG_START_OFFSET, Collections.<FetchResponse.AbortedTransaction>emptyList(), records));
 
-        FetchResponse response = new FetchResponse(responseData, 10);
+        FetchResponse response = new FetchResponse(Errors.NONE, responseData, 10, INVALID_SESSION_ID);
         FetchResponse deserialized = FetchResponse.parse(toBuffer(response.toStruct((short) 4)), (short) 4);
         assertEquals(responseData, deserialized.responseData());
     }
 
     @Test
-    public void verifyFetchResponseFullWrite() throws Exception {
-        FetchResponse fetchResponse = createFetchResponse();
-        short apiVersion = ApiKeys.FETCH.latestVersion();
+    public void verifyFetchResponseFullWrites() throws Exception {
+        verifyFetchResponseFullWrite(ApiKeys.FETCH.latestVersion(), createFetchResponse(123));
+        verifyFetchResponseFullWrite(ApiKeys.FETCH.latestVersion(),
+            createFetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, 123));
+        for (short version = 0; version <= ApiKeys.FETCH.latestVersion(); version++) {
+            verifyFetchResponseFullWrite(version, createFetchResponse());
+        }
+    }
+
+    private void verifyFetchResponseFullWrite(short apiVersion, FetchResponse fetchResponse) throws Exception {
         int correlationId = 15;
 
         Send send = fetchResponse.toSend("1", new ResponseHeader(correlationId), apiVersion);
@@ -559,6 +574,19 @@ public class RequestResponseTest {
     }
 
     @Test
+    public void testFetchRequestWithMetadata() throws Exception {
+        FetchRequest request = createFetchRequest(4, IsolationLevel.READ_COMMITTED);
+        Struct struct = request.toStruct();
+        FetchRequest deserialized = (FetchRequest) deserialize(request, struct, request.version());
+        assertEquals(request.isolationLevel(), deserialized.isolationLevel());
+
+        request = createFetchRequest(4, IsolationLevel.READ_UNCOMMITTED);
+        struct = request.toStruct();
+        deserialized = (FetchRequest) deserialize(request, struct, request.version());
+        assertEquals(request.isolationLevel(), deserialized.isolationLevel());
+    }
+
+    @Test
     public void testJoinGroupRequestVersion0RebalanceTimeout() throws Exception {
         final short version = 0;
         JoinGroupRequest jgr = createJoinGroupRequest(version);
@@ -589,11 +617,20 @@ public class RequestResponseTest {
         return new FindCoordinatorResponse(Errors.NONE, new Node(10, "host1", 2014));
     }
 
+    private FetchRequest createFetchRequest(int version, FetchMetadata metadata, List<TopicPartition> toForget) {
+        LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData = new LinkedHashMap<>();
+        fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, 0L, 1000000));
+        fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, 0L, 1000000));
+        return FetchRequest.Builder.forConsumer(100, 100000, fetchData).
+            metadata(metadata).setMaxBytes(1000).toForget(toForget).build((short) version);
+    }
+
     private FetchRequest createFetchRequest(int version, IsolationLevel isolationLevel) {
         LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData = new LinkedHashMap<>();
         fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, 0L, 1000000));
         fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, 0L, 1000000));
-        return FetchRequest.Builder.forConsumer(100, 100000, fetchData, isolationLevel).setMaxBytes(1000).build((short) version);
+        return FetchRequest.Builder.forConsumer(100, 100000, fetchData).
+            isolationLevel(isolationLevel).setMaxBytes(1000).build((short) version);
     }
 
     private FetchRequest createFetchRequest(int version) {
@@ -603,6 +640,23 @@ public class RequestResponseTest {
         return FetchRequest.Builder.forConsumer(100, 100000, fetchData).setMaxBytes(1000).build((short) version);
     }
 
+    private FetchResponse createFetchResponse(Errors error, int sessionId) {
+        return new FetchResponse(error, new LinkedHashMap<TopicPartition, FetchResponse.PartitionData>(),
+            25, sessionId);
+    }
+
+    private FetchResponse createFetchResponse(int sessionId) {
+        LinkedHashMap<TopicPartition, FetchResponse.PartitionData> responseData = new LinkedHashMap<>();
+        MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes()));
+        responseData.put(new TopicPartition("test", 0), new FetchResponse.PartitionData(Errors.NONE,
+            1000000, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records));
+        List<FetchResponse.AbortedTransaction> abortedTransactions = Collections.singletonList(
+            new FetchResponse.AbortedTransaction(234L, 999L));
+        responseData.put(new TopicPartition("test", 1), new FetchResponse.PartitionData(Errors.NONE,
+            1000000, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, abortedTransactions, MemoryRecords.EMPTY));
+        return new FetchResponse(Errors.NONE, responseData, 25, sessionId);
+    }
+
     private FetchResponse createFetchResponse() {
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData> responseData = new LinkedHashMap<>();
         MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes()));
@@ -614,7 +668,7 @@ public class RequestResponseTest {
         responseData.put(new TopicPartition("test", 1), new FetchResponse.PartitionData(Errors.NONE,
                 1000000, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, abortedTransactions, MemoryRecords.EMPTY));
 
-        return new FetchResponse(responseData, 25);
+        return new FetchResponse(Errors.NONE, responseData, 25, INVALID_SESSION_ID);
     }
 
     private HeartbeatRequest createHeartBeatRequest() {
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashSetTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashSetTest.java
new file mode 100644
index 0000000..20084a2
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashSetTest.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Random;
+
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * A unit test for ImplicitLinkedHashSet.
+ */
+public class ImplicitLinkedHashSetTest {
+    @Rule
+    final public Timeout globalTimeout = Timeout.millis(120000);
+
+    private final static class TestElement implements ImplicitLinkedHashSet.Element {
+        private int prev = ImplicitLinkedHashSet.INVALID_INDEX;
+        private int next = ImplicitLinkedHashSet.INVALID_INDEX;
+        private final int val;
+
+        TestElement(int val) {
+            this.val = val;
+        }
+
+        @Override
+        public int prev() {
+            return prev;
+        }
+
+        @Override
+        public void setPrev(int prev) {
+            this.prev = prev;
+        }
+
+        @Override
+        public int next() {
+            return next;
+        }
+
+        @Override
+        public void setNext(int next) {
+            this.next = next;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if ((o == null) || (o.getClass() != TestElement.class)) return false;
+            TestElement that = (TestElement) o;
+            return val == that.val;
+        }
+
+        @Override
+        public String toString() {
+            return "TestElement(" + val + ")";
+        }
+
+        @Override
+        public int hashCode() {
+            return val;
+        }
+    }
+
+    @Test
+    public void testInsertDelete() throws Exception {
+        ImplicitLinkedHashSet<TestElement> set = new ImplicitLinkedHashSet<>(100);
+        assertTrue(set.add(new TestElement(1)));
+        TestElement second = new TestElement(2);
+        assertTrue(set.add(second));
+        assertTrue(set.add(new TestElement(3)));
+        assertFalse(set.add(new TestElement(3)));
+        assertEquals(3, set.size());
+        assertTrue(set.contains(new TestElement(1)));
+        assertFalse(set.contains(new TestElement(4)));
+        TestElement secondAgain = set.find(new TestElement(2));
+        assertTrue(second == secondAgain);
+        assertTrue(set.remove(new TestElement(1)));
+        assertFalse(set.remove(new TestElement(1)));
+        assertEquals(2, set.size());
+        set.clear();
+        assertEquals(0, set.size());
+    }
+
+    private static void expectTraversal(Iterator<TestElement> iterator, Integer... sequence) {
+        int i = 0;
+        while (iterator.hasNext()) {
+            TestElement element = iterator.next();
+            Assert.assertTrue("Iterator yielded " + (i + 1) + " elements, but only " +
+                sequence.length + " were expected.", i < sequence.length);
+            Assert.assertEquals("Iterator value number " + (i + 1) + " was incorrect.",
+                sequence[i].intValue(), element.val);
+            i = i + 1;
+        }
+        Assert.assertTrue("Iterator yielded " + i + " elements, but " +
+            sequence.length + " were expected.", i == sequence.length);
+    }
+
+    private static void expectTraversal(Iterator<TestElement> iter,
+                                        Iterator<Integer> expectedIter) {
+        int i = 0;
+        while (iter.hasNext()) {
+            TestElement element = iter.next();
+            Assert.assertTrue("Iterator yielded " + (i + 1) + " elements, but only " +
+                i + " were expected.", expectedIter.hasNext());
+            Integer expected = expectedIter.next();
+            Assert.assertEquals("Iterator value number " + (i + 1) + " was incorrect.",
+                expected.intValue(), element.val);
+            i = i + 1;
+        }
+        Assert.assertFalse("Iterator yielded " + i + " elements, but at least " +
+            (i + 1) + " were expected.", expectedIter.hasNext());
+    }
+
+    @Test
+    public void testTraversal() throws Exception {
+        ImplicitLinkedHashSet<TestElement> set = new ImplicitLinkedHashSet<>(100);
+        expectTraversal(set.iterator());
+        assertTrue(set.add(new TestElement(2)));
+        expectTraversal(set.iterator(), 2);
+        assertTrue(set.add(new TestElement(1)));
+        expectTraversal(set.iterator(), 2, 1);
+        assertTrue(set.add(new TestElement(100)));
+        expectTraversal(set.iterator(), 2, 1, 100);
+        assertTrue(set.remove(new TestElement(1)));
+        expectTraversal(set.iterator(), 2, 100);
+        assertTrue(set.add(new TestElement(1)));
+        expectTraversal(set.iterator(), 2, 100, 1);
+        Iterator<TestElement> iter = set.iterator();
+        iter.next();
+        iter.next();
+        iter.remove();
+        iter.next();
+        assertFalse(iter.hasNext());
+        expectTraversal(set.iterator(), 2, 1);
+        List<TestElement> list = new ArrayList<>();
+        list.add(new TestElement(1));
+        list.add(new TestElement(2));
+        assertTrue(set.removeAll(list));
+        assertFalse(set.removeAll(list));
+        expectTraversal(set.iterator());
+        assertEquals(0, set.size());
+        assertTrue(set.isEmpty());
+    }
+
+    @Test
+    public void testCollisions() throws Exception {
+        ImplicitLinkedHashSet<TestElement> set = new ImplicitLinkedHashSet<>(5);
+        assertEquals(11, set.numSlots());
+        assertTrue(set.add(new TestElement(11)));
+        assertTrue(set.add(new TestElement(0)));
+        assertTrue(set.add(new TestElement(22)));
+        assertTrue(set.add(new TestElement(33)));
+        assertEquals(11, set.numSlots());
+        expectTraversal(set.iterator(), 11, 0, 22, 33);
+        assertTrue(set.remove(new TestElement(22)));
+        expectTraversal(set.iterator(), 11, 0, 33);
+        assertEquals(3, set.size());
+        assertFalse(set.isEmpty());
+    }
+
+    @Test
+    public void testEnlargement() throws Exception {
+        ImplicitLinkedHashSet<TestElement> set = new ImplicitLinkedHashSet<>(5);
+        assertEquals(11, set.numSlots());
+        for (int i = 0; i < 6; i++) {
+            assertTrue(set.add(new TestElement(i)));
+        }
+        assertEquals(23, set.numSlots());
+        assertEquals(6, set.size());
+        expectTraversal(set.iterator(), 0, 1, 2, 3, 4, 5);
+        for (int i = 0; i < 6; i++) {
+            assertTrue("Failed to find element " + i, set.contains(new TestElement(i)));
+        }
+        set.remove(new TestElement(3));
+        assertEquals(23, set.numSlots());
+        assertEquals(5, set.size());
+        expectTraversal(set.iterator(), 0, 1, 2, 4, 5);
+    }
+
+    @Test
+    public void testManyInsertsAndDeletes() throws Exception {
+        Random random = new Random(123);
+        LinkedHashSet<Integer> existing = new LinkedHashSet<>();
+        ImplicitLinkedHashSet<TestElement> set = new ImplicitLinkedHashSet<>();
+        for (int i = 0; i < 100; i++) {
+            addRandomElement(random, existing, set);
+            addRandomElement(random, existing, set);
+            addRandomElement(random, existing, set);
+            removeRandomElement(random, existing, set);
+            expectTraversal(set.iterator(), existing.iterator());
+        }
+    }
+
+    private void addRandomElement(Random random, LinkedHashSet<Integer> existing,
+                                  ImplicitLinkedHashSet<TestElement> set) {
+        int next;
+        do {
+            next = random.nextInt();
+        } while (existing.contains(next));
+        existing.add(next);
+        set.add(new TestElement(next));
+    }
+
+    private void removeRandomElement(Random random, LinkedHashSet<Integer> existing,
+                                     ImplicitLinkedHashSet<TestElement> set) {
+        int removeIdx = random.nextInt(existing.size());
+        Iterator<Integer> iter = existing.iterator();
+        Integer element = null;
+        for (int i = 0; i <= removeIdx; i++) {
+            element = iter.next();
+        }
+        existing.remove(element);
+        set.remove(new TestElement(element));
+    }
+}
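
The Element contract exercised above (prev() and next() returning int slot indices) is
what makes the set "implicit": list linkage lives inside the stored elements as array
indices rather than in separately allocated node objects. A minimal sketch of the idea,
not the Kafka class itself, with append-only placement and no hashing:

    final class SketchElement(val value: Int) {
      var prev: Int = -1  // slot index of the previous element, -1 for none
      var next: Int = -1  // slot index of the next element, -1 for none
    }

    final class SketchSet(capacity: Int) {
      private val slots = new Array[SketchElement](capacity)
      private var head = -1
      private var tail = -1
      private var count = 0

      // Append into the next free slot; the real set hashes into open slots.
      def add(e: SketchElement): Boolean = {
        if (count == capacity) return false
        val slot = count
        slots(slot) = e
        e.prev = tail
        if (tail >= 0) slots(tail).next = slot else head = slot
        tail = slot
        count += 1
        true
      }

      // Walk insertion order by chasing int indices.
      def foreachInOrder(f: Int => Unit): Unit = {
        var i = head
        while (i >= 0) { f(slots(i).value); i = slots(i).next }
      }
    }
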
diff --git a/core/src/main/scala/kafka/api/ApiVersion.scala b/core/src/main/scala/kafka/api/ApiVersion.scala
index f95fb89..b8329c1 100644
--- a/core/src/main/scala/kafka/api/ApiVersion.scala
+++ b/core/src/main/scala/kafka/api/ApiVersion.scala
@@ -73,8 +73,10 @@ object ApiVersion {
     // Introduced LeaderAndIsrRequest V1, UpdateMetadataRequest V4 and FetchRequest V6 via KIP-112
     "1.0-IV0" -> KAFKA_1_0_IV0,
     "1.0" -> KAFKA_1_0_IV0,
-    // Introduced DeleteGroupsRequest V0 via KIP-229
-    "1.1-IV0" -> KAFKA_1_1_IV0
+    // Introduced DeleteGroupsRequest V0 via KIP-229, plus KIP-227 incremental fetch requests,
+    // and KafkaStorageException for fetch requests.
+    "1.1-IV0" -> KAFKA_1_1_IV0,
+    "1.1" -> KAFKA_1_1_IV0
   )
 
   private val versionPattern = "\\.".r
@@ -191,4 +193,3 @@ case object KAFKA_1_1_IV0 extends ApiVersion {
   val messageFormatVersion: Byte = RecordBatch.MAGIC_VALUE_V2
   val id: Int = 14
 }
-
diff --git a/core/src/main/scala/kafka/server/FetchSession.scala b/core/src/main/scala/kafka/server/FetchSession.scala
new file mode 100644
index 0000000..0a825f1
--- /dev/null
+++ b/core/src/main/scala/kafka/server/FetchSession.scala
@@ -0,0 +1,720 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.server
+
+import java.util
+import java.util.concurrent.{ThreadLocalRandom, TimeUnit}
+
+import com.yammer.metrics.core.Gauge
+import kafka.metrics.KafkaMetricsGroup
+import kafka.utils.Logging
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID}
+import org.apache.kafka.common.requests.{FetchRequest, FetchResponse}
+import org.apache.kafka.common.requests.{FetchMetadata => JFetchMetadata}
+import org.apache.kafka.common.utils.{ImplicitLinkedHashSet, Time, Utils}
+
+import scala.math.Ordered.orderingToOrdered
+import scala.collection.{mutable, _}
+import scala.collection.JavaConverters._
+
+object FetchSession {
+  type REQ_MAP = util.Map[TopicPartition, FetchRequest.PartitionData]
+  type RESP_MAP = util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+  type CACHE_MAP = ImplicitLinkedHashSet[CachedPartition]
+
+  val NUM_INCREMENTAL_FETCH_SESSIONS = "NumIncrementalFetchSessions"
+  val NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED = "NumIncrementalFetchPartitionsCached"
+  val INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC = "IncrementalFetchSessionEvictionsPerSec"
+  val EVICTIONS = "evictions"
+
+  def partitionsToLogString(partitions: util.Collection[TopicPartition], traceEnabled: Boolean): String = {
+    if (traceEnabled) {
+      "(" + Utils.join(partitions, ", ") + ")"
+    } else {
+      s"${partitions.size} partition(s)"
+    }
+  }
+}
+
+/**
+  * A cached partition.
+  *
+  * The broker maintains a set of these objects for each incremental fetch session.
+  * When an incremental fetch request is made, any partitions which are not explicitly
+  * enumerated in the fetch request are loaded from the cache.  Similarly, when an
+  * incremental fetch response is being prepared, any partitions that have not changed
+  * are left out of the response.
+  *
+  * We store many of these objects, so it is important for them to be memory-efficient.
+  * That is why we store topic and partition separately rather than storing a TopicPartition
+  * object.  The TP object takes up more memory because it is a separate JVM object, and
+  * because it stores the cached hash code in memory.
+  *
+  * Note that fetcherLogStartOffset is the LSO of the follower performing the fetch, whereas
+  * localLogStartOffset is the log start offset of the partition on this broker.
+  */
+class CachedPartition(val topic: String,
+                      val partition: Int,
+                      var maxBytes: Int,
+                      var fetchOffset: Long,
+                      var highWatermark: Long,
+                      var fetcherLogStartOffset: Long,
+                      var localLogStartOffset: Long)
+    extends ImplicitLinkedHashSet.Element {
+
+  var cachedNext: Int = ImplicitLinkedHashSet.INVALID_INDEX
+  var cachedPrev: Int = ImplicitLinkedHashSet.INVALID_INDEX
+
+  override def next = cachedNext
+  override def setNext(next: Int) = this.cachedNext = next
+  override def prev = cachedPrev
+  override def setPrev(prev: Int) = this.cachedPrev = prev
+
+  def this(topic: String, partition: Int) =
+    this(topic, partition, -1, -1, -1, -1, -1)
+
+  def this(part: TopicPartition) =
+    this(part.topic(), part.partition())
+
+  def this(part: TopicPartition, reqData: FetchRequest.PartitionData) =
+    this(part.topic(), part.partition(),
+      reqData.maxBytes, reqData.fetchOffset, -1,
+      reqData.logStartOffset, -1)
+
+  def this(part: TopicPartition, reqData: FetchRequest.PartitionData,
+           respData: FetchResponse.PartitionData) =
+    this(part.topic(), part.partition(),
+      reqData.maxBytes, reqData.fetchOffset, respData.highWatermark,
+      reqData.logStartOffset, respData.logStartOffset)
+
+  def topicPartition() = new TopicPartition(topic, partition)
+
+  def reqData() = new FetchRequest.PartitionData(fetchOffset, fetcherLogStartOffset, maxBytes)
+
+  def updateRequestParams(reqData: FetchRequest.PartitionData): Unit = {
+    // Update our cached request parameters.
+    maxBytes = reqData.maxBytes
+    fetchOffset = reqData.fetchOffset
+    fetcherLogStartOffset = reqData.logStartOffset
+  }
+
+  /**
+    * Update this CachedPartition with new request and response data.
+    *
+    * This function should be called while holding the appropriate session
+    * lock.
+    *
+    * @return True if this partition should be included in the FetchResponse
+    *         we send back to the fetcher; false if it can be omitted.
+    */
+  def updateResponseData(respData: FetchResponse.PartitionData): Boolean = {
+    // Check the response data.
+    var mustRespond = false
+    if ((respData.records != null) && (respData.records.sizeInBytes() > 0)) {
+      // Partitions with new data are always included in the response.
+      mustRespond = true
+    }
+    if (highWatermark != respData.highWatermark) {
+      mustRespond = true
+      highWatermark = respData.highWatermark
+    }
+    if (localLogStartOffset != respData.logStartOffset) {
+      mustRespond = true
+      localLogStartOffset = respData.logStartOffset
+    }
+    if (respData.error.code() != 0) {
+      // Partitions with errors are always included in the response.
+      // We also set the cached highWatermark to an invalid offset, -1.
+      // This ensures that when the error goes away, we re-send the partition.
+      highWatermark = -1
+      mustRespond = true
+    }
+    mustRespond
+  }
+
+  override def hashCode() = (31 * partition) + topic.hashCode
+
+  def canEqual(that: Any) = that.isInstanceOf[CachedPartition]
+
+  override def equals(that: Any): Boolean =
+    that match {
+      case that: CachedPartition => that.canEqual(this) &&
+        this.topic.equals(that.topic) &&
+        this.partition.equals(that.partition)
+      case _ => false
+    }
+
+  override def toString() = synchronized {
+    "CachedPartition(topic=" + topic +
+      ", partition=" + partition +
+      ", maxBytes=" + maxBytes +
+      ", fetchOffset=" + fetchOffset +
+      ", highWatermark=" + highWatermark +
+      ", fetcherLogStartOffset=" + fetcherLogStartOffset +
+      ", localLogStartOffset=" + localLogStartOffset +
+      ")"
+  }
+}
+
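
To make the inclusion rules of updateResponseData concrete, a hypothetical walk-through
(the CachedPartition constructor arguments follow the field order declared above, and the
PartitionData shape matches its use elsewhere in this patch):

    import org.apache.kafka.common.protocol.Errors
    import org.apache.kafka.common.record.MemoryRecords
    import org.apache.kafka.common.requests.FetchResponse

    // Cached state: fetching at offset 100, last reported HW 500, log start offsets 0.
    val cached = new CachedPartition("logs", 0, 1048576, 100L, 500L, 0L, 0L)

    // No new records, no watermark or log-start movement: omit from the response.
    val unchanged = new FetchResponse.PartitionData(
      Errors.NONE, 500L, 500L, 0L, null, MemoryRecords.EMPTY)
    assert(!cached.updateResponseData(unchanged))

    // The high watermark advanced, so the partition must be sent back.
    val advanced = new FetchResponse.PartitionData(
      Errors.NONE, 600L, 600L, 0L, null, MemoryRecords.EMPTY)
    assert(cached.updateResponseData(advanced))
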
+/**
+  * The fetch session.
+  *
+  * Each fetch session is protected by its own lock, which must be taken before mutable
+  * fields are read or modified.  This includes modification of the session partition map.
+  *
+  * @param id           The unique fetch session ID.
+  * @param privileged   True if this session is privileged.  Sessions created by followers
+  *                     are privileged; sessions created by consumers are not.
+  * @param partitionMap The CachedPartitionMap.
+  * @param creationMs   The time in milliseconds when this session was created.
+  * @param lastUsedMs   The last used time in milliseconds.  This should only be updated by
+  *                     FetchSessionCache#touch.
+  * @param epoch        The fetch session sequence number.
+  */
+case class FetchSession(val id: Int,
+                        val privileged: Boolean,
+                        val partitionMap: FetchSession.CACHE_MAP,
+                        val creationMs: Long,
+                        var lastUsedMs: Long,
+                        var epoch: Int) {
+  // This is used by the FetchSessionCache to store the last known size of this session.
+  // If this is -1, the Session is not in the cache.
+  var cachedSize = -1
+
+  def size(): Int = synchronized {
+    partitionMap.size()
+  }
+
+  def isEmpty(): Boolean = synchronized {
+    partitionMap.isEmpty
+  }
+
+  def lastUsedKey(): LastUsedKey = synchronized {
+    LastUsedKey(lastUsedMs, id)
+  }
+
+  def evictableKey(): EvictableKey = synchronized {
+    EvictableKey(privileged, cachedSize, id)
+  }
+
+  def metadata(): JFetchMetadata = synchronized { new JFetchMetadata(id, epoch) }
+
+  def getFetchOffset(topicPartition: TopicPartition): Option[Long] = synchronized {
+    Option(partitionMap.find(new CachedPartition(topicPartition))).map(_.fetchOffset)
+  }
+
+  type TL = util.ArrayList[TopicPartition]
+
+  // Update the cached partition data based on the request.
+  def update(fetchData: FetchSession.REQ_MAP,
+             toForget: util.List[TopicPartition],
+             reqMetadata: JFetchMetadata): (TL, TL, TL) = synchronized {
+    val added = new TL
+    val updated = new TL
+    val removed = new TL
+    fetchData.entrySet().iterator().asScala.foreach(entry => {
+      val topicPart = entry.getKey
+      val reqData = entry.getValue
+      val newCachedPart = new CachedPartition(topicPart, reqData)
+      val cachedPart = partitionMap.find(newCachedPart)
+      if (cachedPart == null) {
+        partitionMap.mustAdd(newCachedPart)
+        added.add(topicPart)
+      } else {
+        cachedPart.updateRequestParams(reqData)
+        updated.add(topicPart)
+      }
+    })
+    toForget.iterator().asScala.foreach(p => {
+      if (partitionMap.remove(new CachedPartition(p.topic(), p.partition()))) {
+        removed.add(p)
+      }
+    })
+    (added, updated, removed)
+  }
+
+  override def toString(): String = synchronized {
+    "FetchSession(id=" + id +
+      ", privileged=" + privileged +
+      ", partitionMap.size=" + partitionMap.size() +
+      ", creationMs=" + creationMs +
+      ", lastUsedMs=" + lastUsedMs +
+      ", epoch=" + epoch + ")"
+  }
+}
+
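
A hypothetical call showing update()'s three-way result, assuming a session that already
caches ("foo", 0) and ("bar", 0); a re-sent partition lands in the updated list, an
unknown one in added, and a forgotten one in removed:

    import java.util.{Collections, LinkedHashMap}
    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.requests.{FetchRequest, FetchMetadata => JFetchMetadata}

    val cachedParts = new FetchSession.CACHE_MAP(2)
    cachedParts.mustAdd(new CachedPartition("foo", 0))
    cachedParts.mustAdd(new CachedPartition("bar", 0))
    val session = FetchSession(1, false, cachedParts, 0L, 0L, 1)

    val reqData = new LinkedHashMap[TopicPartition, FetchRequest.PartitionData]()
    reqData.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(10L, 0L, 4096))
    reqData.put(new TopicPartition("baz", 0), new FetchRequest.PartitionData(0L, 0L, 4096))
    val (added, updated, removed) = session.update(reqData,
      Collections.singletonList(new TopicPartition("bar", 0)),
      new JFetchMetadata(1, 1))
    // added == [baz-0], updated == [foo-0], removed == [bar-0]
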
+trait FetchContext extends Logging {
+  /**
+    * Get the fetch offset for a given partition.
+    */
+  def getFetchOffset(part: TopicPartition): Option[Long]
+
+  /**
+    * Apply a function to each partition in the fetch request.
+    */
+  def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit
+
+  /**
+    * Updates the fetch context with new partition information.  Generates response data.
+    * The response data may require subsequent down-conversion.
+    */
+  def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse
+
+  def partitionsToLogString(partitions: util.Collection[TopicPartition]): String =
+    FetchSession.partitionsToLogString(partitions, isTraceEnabled)
+}
+
+/**
+  * The fetch context for a fetch request that had a session error.
+  */
+class SessionErrorContext(val error: Errors,
+                          val reqMetadata: JFetchMetadata) extends FetchContext {
+  override def getFetchOffset(part: TopicPartition): Option[Long] = None
+
+  override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = {}
+
+  // Because of the fetch session error, we don't know what partitions were supposed to be in this request.
+  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = {
+    debug(s"Session error fetch context returning $error")
+    new FetchResponse(error, new FetchSession.RESP_MAP, 0, INVALID_SESSION_ID)
+  }
+}
+
+/**
+  * The fetch context for a sessionless fetch request.
+  *
+  * @param fetchData          The partition data from the fetch request.
+  */
+class SessionlessFetchContext(val fetchData: util.Map[TopicPartition, FetchRequest.PartitionData]) extends FetchContext {
+  override def getFetchOffset(part: TopicPartition): Option[Long] =
+    Option(fetchData.get(part)).map(_.fetchOffset)
+
+  override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = {
+    fetchData.entrySet().asScala.foreach(entry => fun(entry.getKey, entry.getValue))
+  }
+
+  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = {
+    debug(s"Sessionless fetch context returning ${partitionsToLogString(updates.keySet())}")
+    new FetchResponse(Errors.NONE, updates, 0, INVALID_SESSION_ID)
+  }
+}
+
+/**
+  * The fetch context for a full fetch request.
+  *
+  * @param time               The clock to use.
+  * @param cache              The fetch session cache.
+  * @param reqMetadata        The request metadata.
+  * @param fetchData          The partition data from the fetch request.
+  * @param isFromFollower     True if this fetch request came from a follower.
+  */
+class FullFetchContext(private val time: Time,
+                       private val cache: FetchSessionCache,
+                       private val reqMetadata: JFetchMetadata,
+                       private val fetchData: util.Map[TopicPartition, FetchRequest.PartitionData],
+                       private val isFromFollower: Boolean) extends FetchContext {
+  override def getFetchOffset(part: TopicPartition): Option[Long] =
+    Option(fetchData.get(part)).map(_.fetchOffset)
+
+  override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = {
+    fetchData.entrySet().asScala.foreach(entry => fun(entry.getKey, entry.getValue))
+  }
+
+  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = {
+    def createNewSession(): FetchSession.CACHE_MAP = {
+      val cachedPartitions = new FetchSession.CACHE_MAP(updates.size())
+      updates.entrySet().asScala.foreach(entry => {
+        val part = entry.getKey
+        val respData = entry.getValue
+        val reqData = fetchData.get(part)
+        cachedPartitions.mustAdd(new CachedPartition(part, reqData, respData))
+      })
+      cachedPartitions
+    }
+    val responseSessionId = cache.maybeCreateSession(time.milliseconds(), isFromFollower,
+        updates.size(), createNewSession)
+    debug(s"Full fetch context with session id $responseSessionId returning " +
+      s"${partitionsToLogString(updates.keySet())}")
+    new FetchResponse(Errors.NONE, updates, 0, responseSessionId)
+  }
+}
+
+/**
+  * The fetch context for an incremental fetch request.
+  *
+  * @param time         The clock to use.
+  * @param reqMetadata  The request metadata.
+  * @param session      The incremental fetch request session.
+  */
+class IncrementalFetchContext(private val time: Time,
+                              private val reqMetadata: JFetchMetadata,
+                              private val session: FetchSession) extends FetchContext {
+
+  override def getFetchOffset(tp: TopicPartition): Option[Long] = session.getFetchOffset(tp)
+
+  override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = {
+    // Take the session lock and iterate over all the cached partitions.
+    session.synchronized {
+      session.partitionMap.iterator().asScala.foreach(part => {
+        fun(new TopicPartition(part.topic, part.partition), part.reqData())
+      })
+    }
+  }
+
+  override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = {
+    session.synchronized {
+      // Check to make sure that the session epoch didn't change in between
+      // creating this fetch context and generating this response.
+      val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch())
+      if (session.epoch != expectedEpoch) {
+        info(s"Incremental fetch session ${session.id} expected epoch $expectedEpoch, but " +
+          s"got ${session.epoch}.  Possible duplicate request.")
+        new FetchResponse(Errors.INVALID_FETCH_SESSION_EPOCH, new FetchSession.RESP_MAP, 0, session.id)
+      } else {
+        // Iterate over the update list.  Prune updates which don't need to be sent.
+        val iter = updates.entrySet().iterator()
+        while (iter.hasNext()) {
+          val entry = iter.next()
+          val topicPart = entry.getKey
+          val respData = entry.getValue
+          val cachedPart = session.partitionMap.find(new CachedPartition(topicPart))
+          val mustRespond = cachedPart.updateResponseData(respData)
+          if (mustRespond) {
+            // Move this to the end of the cached partition map.
+            // This is important for ensuring fairness when lots of partitions
+            // have data to return.
+            session.partitionMap.remove(cachedPart)
+            session.partitionMap.mustAdd(cachedPart)
+          } else {
+            // Do not include this partition in the FetchResponse.
+            iter.remove()
+          }
+        }
+        debug(s"Incremental fetch context with session id ${session.id} returning " +
+          s"${partitionsToLogString(updates.keySet())}")
+        new FetchResponse(Errors.NONE, updates, 0, session.id)
+      }
+    }
+  }
+}
+
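
The remove-then-mustAdd step above is what keeps a large incremental session fair: a
partition that just returned data moves to the back of the iteration order, so the
partitions that have waited longest are visited first on the next fetch. The same
rotation, sketched against a plain ordered set purely for illustration:

    import java.util.LinkedHashSet

    // Move an element that just "responded" to the tail of the iteration order.
    def rotateToEnd[T](ordered: LinkedHashSet[T], responded: T): Unit = {
      if (ordered.remove(responded))
        ordered.add(responded)
    }
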
+case class LastUsedKey(val lastUsedMs: Long,
+                       val id: Int) extends Comparable[LastUsedKey] {
+  override def compareTo(other: LastUsedKey): Int =
+    (lastUsedMs, id) compare (other.lastUsedMs, other.id)
+}
+
+case class EvictableKey(val privileged: Boolean,
+                        val size: Int,
+                        val id: Int) extends Comparable[EvictableKey] {
+  override def compareTo(other: EvictableKey): Int =
+    (privileged, size, id) compare (other.privileged, other.size, other.id)
+}
+
+/**
+  * Caches fetch sessions.
+  *
+  * See tryEvict for an explanation of the cache eviction strategy.
+  *
+  * The FetchSessionCache is thread-safe because all of its methods are synchronized.
+  * Note that individual fetch sessions have their own locks which are separate from the
+  * FetchSessionCache lock.  In order to avoid deadlock, the FetchSessionCache lock
+  * must never be acquired while an individual FetchSession lock is already held.
+  *
+  * @param maxEntries The maximum number of entries that can be in the cache.
+  * @param evictionMs The minimum time that an entry must be unused in order to be evictable.
+  */
+class FetchSessionCache(private val maxEntries: Int,
+                        private val evictionMs: Long) extends Logging with KafkaMetricsGroup {
+  private var numPartitions: Long = 0
+
+  // A map of session ID to FetchSession.
+  private val sessions = new mutable.HashMap[Int, FetchSession]
+
+  // Maps last used times to sessions.
+  private val lastUsed = new util.TreeMap[LastUsedKey, FetchSession]
+
+  // A map containing sessions which can be evicted by both privileged and
+  // unprivileged sessions.
+  private val evictableByAll = new util.TreeMap[EvictableKey, FetchSession]
+
+  // A map containing sessions which can be evicted by privileged sessions.
+  private val evictableByPrivileged = new util.TreeMap[EvictableKey, FetchSession]
+
+  // Set up metrics.
+  removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS)
+  newGauge(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS,
+    new Gauge[Int] {
+      def value = FetchSessionCache.this.size
+    }
+  )
+  removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED)
+  newGauge(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED,
+    new Gauge[Long] {
+      def value = FetchSessionCache.this.totalPartitions
+    }
+  )
+  removeMetric(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC)
+  val evictionsMeter = newMeter(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC,
+    FetchSession.EVICTIONS, TimeUnit.SECONDS, Map.empty)
+
+  /**
+    * Get a session by session ID.
+    *
+    * @param sessionId  The session ID.
+    * @return           The session, or None if no such session was found.
+    */
+  def get(sessionId: Int): Option[FetchSession] = synchronized {
+    sessions.get(sessionId)
+  }
+
+  /**
+    * Get the number of entries currently in the fetch session cache.
+    */
+  def size(): Int = synchronized {
+    sessions.size
+  }
+
+  /**
+    * Get the total number of cached partitions.
+    */
+  def totalPartitions(): Long = synchronized {
+    numPartitions
+  }
+
+  /**
+    * Creates a new random session ID.  The new session ID will be positive and unique on this broker.
+    *
+    * @return   The new session ID.
+    */
+  def newSessionId(): Int = synchronized {
+    var id = 0
+    do {
+      id = ThreadLocalRandom.current().nextInt(1, Int.MaxValue)
+    } while (sessions.contains(id) || id == INVALID_SESSION_ID)
+    id
+  }
+
+  /**
+    * Try to create a new session.
+    *
+    * @param now                The current time in milliseconds.
+    * @param privileged         True if the new entry we are trying to create is privileged.
+    * @param size               The number of cached partitions in the new entry we are trying to create.
+    * @param createPartitions   A callback function which creates the map of cached partitions.
+    * @return                   If we created a session, the ID; INVALID_SESSION_ID otherwise.
+    */
+  def maybeCreateSession(now: Long,
+                         privileged: Boolean,
+                         size: Int,
+                         createPartitions: () => FetchSession.CACHE_MAP): Int =
+  synchronized {
+    // If there is room, create a new session entry.
+    if ((sessions.size < maxEntries) ||
+        tryEvict(privileged, EvictableKey(privileged, size, 0), now)) {
+      val partitionMap = createPartitions()
+      val session = new FetchSession(newSessionId(), privileged, partitionMap,
+          now, now, JFetchMetadata.nextEpoch(INITIAL_EPOCH))
+      debug(s"Created fetch session ${session.toString()}")
+      sessions.put(session.id, session)
+      touch(session, now)
+      session.id
+    } else {
+      debug(s"No fetch session created for privileged=$privileged, size=$size.")
+      INVALID_SESSION_ID
+    }
+  }
+
+  /**
+    * Try to evict an entry from the session cache.
+    *
+    * A proposed new element A may evict an existing element B if:
+    * 1. A is privileged and B is not, or
+    * 2. B is considered "stale" because it has been inactive for a long time, or
+    * 3. A contains more partitions than B, and B is not recently created.
+    *
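+    * For example, if the cache is full and an unprivileged request proposes a session
+    * with 100 partitions: a least recently used entry that has been idle for longer
+    * than evictionMs is evicted regardless of size; failing that, the lowest-ranked
+    * entry in evictableByAll is evicted, provided it is unprivileged and holds fewer
+    * than 100 partitions.  (The numbers here are illustrative.)
+    *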
+    * @param privileged True if the new entry we would like to add is privileged.
+    * @param key        The EvictableKey for the new entry we would like to add.
+    * @param now        The current time in milliseconds.
+    * @return           True if an entry was evicted; false otherwise.
+    */
+  def tryEvict(privileged: Boolean, key: EvictableKey, now: Long): Boolean = synchronized {
+    // Try to evict an entry which is stale.
+    val lastUsedEntry = lastUsed.firstEntry()
+    if (lastUsedEntry == null) {
+      trace("There are no cache entries to evict.")
+      false
+    } else if (now - lastUsedEntry.getKey().lastUsedMs > evictionMs) {
+      val session = lastUsedEntry.getValue()
+      trace(s"Evicting stale FetchSession ${session.id}.")
+      remove(session)
+      evictionsMeter.mark()
+      true
+    } else {
+      // If there are no stale entries, check the first evictable entry.
+      // If it is less valuable than our proposed entry, evict it.
+      val map = if (privileged) evictableByPrivileged else evictableByAll
+      val evictableEntry = map.firstEntry()
+      if (evictableEntry == null) {
+        trace("No evictable entries found.")
+        false
+      } else if (key.compareTo(evictableEntry.getKey()) < 0) {
+        trace(s"Can't evict ${evictableEntry.getKey()} with ${key.toString}")
+        false
+      } else {
+        trace(s"Evicting ${evictableEntry.getKey()} with ${key.toString}.")
+        remove(evictableEntry.getValue())
+        evictionsMeter.mark()
+        true
+      }
+    }
+  }
+
+  def remove(sessionId: Int): Option[FetchSession] = synchronized {
+    get(sessionId) match {
+      case None => None
+      case Some(session) => remove(session)
+    }
+  }
+
+  /**
+    * Remove an entry from the session cache.
+    *
+    * @param session  The session.
+    *
+    * @return         The removed session, or None if there was no such session.
+    */
+  def remove(session: FetchSession): Option[FetchSession] = synchronized {
+    val evictableKey = session.synchronized {
+      lastUsed.remove(session.lastUsedKey())
+      session.evictableKey()
+    }
+    evictableByAll.remove(evictableKey)
+    evictableByPrivileged.remove(evictableKey)
+    val removeResult = sessions.remove(session.id)
+    if (removeResult.isDefined) {
+      numPartitions = numPartitions - session.cachedSize
+    }
+    removeResult
+  }
+
+  /**
+    * Update a session's position in the lastUsed and evictable trees.
+    *
+    * @param session  The session.
+    * @param now      The current time in milliseconds.
+    */
+  def touch(session: FetchSession, now: Long): Unit = synchronized {
+    session.synchronized {
+      // Update the lastUsed map.
+      lastUsed.remove(session.lastUsedKey())
+      session.lastUsedMs = now
+      lastUsed.put(session.lastUsedKey(), session)
+
+      val oldSize = session.cachedSize
+      if (oldSize != -1) {
+        val oldEvictableKey = session.evictableKey()
+        evictableByPrivileged.remove(oldEvictableKey)
+        evictableByAll.remove(oldEvictableKey)
+        numPartitions = numPartitions - oldSize
+      }
+      session.cachedSize = session.size()
+      val newEvictableKey = session.evictableKey()
+      if ((!session.privileged) || (now - session.creationMs > evictionMs)) {
+        evictableByPrivileged.put(newEvictableKey, session)
+      }
+      if (now - session.creationMs > evictionMs) {
+        evictableByAll.put(newEvictableKey, session)
+      }
+      numPartitions = numPartitions + session.cachedSize
+    }
+  }
+}
+
+class FetchManager(private val time: Time,
+                   private val cache: FetchSessionCache) extends Logging {
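+  /**
+    * Create a fetch context for a fetch request.
+    *
+    * A full fetch request (one whose epoch is INITIAL_EPOCH or FINAL_EPOCH) closes any
+    * session named in its metadata; it then gets a sessionless context if the epoch is
+    * FINAL_EPOCH, or a full fetch context (which may create a new session) otherwise.
+    * An incremental fetch request must name an existing session and supply the epoch
+    * the session expects; otherwise it gets a context that reports an error code.
+    */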
+  def newContext(reqMetadata: JFetchMetadata,
+                 fetchData: FetchSession.REQ_MAP,
+                 toForget: util.List[TopicPartition],
+                 isFollower: Boolean): FetchContext = {
+    val context = if (reqMetadata.isFull) {
+      var removedFetchSessionStr = ""
+      if (reqMetadata.sessionId() != INVALID_SESSION_ID) {
+        // Any session specified in a FULL fetch request will be closed.
+        if (cache.remove(reqMetadata.sessionId()).isDefined) {
+          removedFetchSessionStr = s" Removed fetch session ${reqMetadata.sessionId()}."
+        }
+      }
+      var suffix = ""
+      val context = if (reqMetadata.epoch() == FINAL_EPOCH) {
+        // If the epoch is FINAL_EPOCH, don't try to create a new session.
+        suffix = " Will not try to create a new session."
+        new SessionlessFetchContext(fetchData)
+      } else {
+        new FullFetchContext(time, cache, reqMetadata, fetchData, isFollower)
+      }
+      debug(s"Created a new full FetchContext with ${partitionsToLogString(fetchData.keySet())}."+
+        s"${removedFetchSessionStr}${suffix}")
+      context
+    } else {
+      cache.synchronized {
+        cache.get(reqMetadata.sessionId()) match {
+          case None => {
+            info(s"Created a new error FetchContext for session id ${reqMetadata.sessionId()}: " +
+              "no such session ID found.")
+            new SessionErrorContext(Errors.FETCH_SESSION_ID_NOT_FOUND, reqMetadata)
+          }
+          case Some(session) => session.synchronized {
+            if (session.epoch != reqMetadata.epoch()) {
+              debug(s"Created a new error FetchContext for session id ${session.id}: expected " +
+                s"epoch ${session.epoch}, but got epoch ${reqMetadata.epoch()}.")
+              new SessionErrorContext(Errors.INVALID_FETCH_SESSION_EPOCH, reqMetadata)
+            } else {
+              val (added, updated, removed) = session.update(fetchData, toForget, reqMetadata)
+              if (session.isEmpty) {
+                debug(s"Created a new sessionless FetchContext and closing session id ${session.id}, " +
+                  s"epoch ${session.epoch}: after removing ${partitionsToLogString(removed)}, " +
+                  s"there are no more partitions left.")
+                cache.remove(session)
+                new SessionlessFetchContext(fetchData)
+              } else {
+                if (session.size() != session.cachedSize) {
+                  // If the number of partitions in the session changed, update the session's
+                  // position in the cache.
+                  cache.touch(session, session.lastUsedMs)
+                }
+                session.epoch = JFetchMetadata.nextEpoch(session.epoch)
+                debug(s"Created a new incremental FetchContext for session id ${session.id}, " +
+                  s"epoch ${session.epoch}: added ${partitionsToLogString(added)}, " +
+                  s"updated ${partitionsToLogString(updated)}, " +
+                  s"removed ${partitionsToLogString(removed)}")
+                new IncrementalFetchContext(time, reqMetadata, session)
+              }
+            }
+          }
+        }
+      }
+    }
+    context
+  }
+
+  def partitionsToLogString(partitions: util.Collection[TopicPartition]): String =
+    FetchSession.partitionsToLogString(partitions, isTraceEnabled)
+}
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 1f448af..b84587f 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -80,6 +80,7 @@ class KafkaApis(val requestChannel: RequestChannel,
                 val metrics: Metrics,
                 val authorizer: Option[Authorizer],
                 val quotas: QuotaManagers,
+                val fetchManager: FetchManager,
                 brokerTopicStats: BrokerTopicStats,
                 val clusterId: String,
                 time: Time,
@@ -481,35 +482,52 @@ class KafkaApis(val requestChannel: RequestChannel,
    * Handle a fetch request
    */
   def handleFetchRequest(request: RequestChannel.Request) {
-    val fetchRequest = request.body[FetchRequest]
     val versionId = request.header.apiVersion
     val clientId = request.header.clientId
-
-    val unauthorizedTopicResponseData = mutable.ArrayBuffer[(TopicPartition, FetchResponse.PartitionData)]()
-    val nonExistingTopicResponseData = mutable.ArrayBuffer[(TopicPartition, FetchResponse.PartitionData)]()
-    val authorizedRequestInfo = mutable.ArrayBuffer[(TopicPartition, FetchRequest.PartitionData)]()
-
-    if (fetchRequest.isFromFollower() && !authorize(request.session, ClusterAction, Resource.ClusterResource))
-      for (topicPartition <- fetchRequest.fetchData.asScala.keys)
-        unauthorizedTopicResponseData += topicPartition -> new FetchResponse.PartitionData(Errors.CLUSTER_AUTHORIZATION_FAILED,
-          FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
-          FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
-    else
-      for ((topicPartition, partitionData) <- fetchRequest.fetchData.asScala) {
-        if (!authorize(request.session, Read, new Resource(Topic, topicPartition.topic)))
-          unauthorizedTopicResponseData += topicPartition -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED,
+    val fetchRequest = request.body[FetchRequest]
+    val fetchContext = fetchManager.newContext(fetchRequest.metadata(),
+          fetchRequest.fetchData(),
+          fetchRequest.toForget(),
+          fetchRequest.isFromFollower())
+
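+    // The fetch context resolves this request against any cached session state, so the
+    // foreachPartition calls below iterate over the effective set of fetch partitions.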
+    val erroneous = mutable.ArrayBuffer[(TopicPartition, FetchResponse.PartitionData)]()
+    val interesting = mutable.ArrayBuffer[(TopicPartition, FetchRequest.PartitionData)]()
+    if (fetchRequest.isFromFollower()) {
+      // The follower must have ClusterAction on ClusterResource in order to fetch partition data.
+      if (authorize(request.session, ClusterAction, Resource.ClusterResource)) {
+        fetchContext.foreachPartition((part, data) => {
+          if (!metadataCache.contains(part.topic)) {
+            erroneous += part -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION,
+              FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
+              FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
+          } else {
+            interesting += (part -> data)
+          }
+        })
+      } else {
+        fetchContext.foreachPartition((part, data) => {
+          erroneous += part -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED,
             FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
             FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
-        else if (!metadataCache.contains(topicPartition.topic))
-          nonExistingTopicResponseData += topicPartition -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION,
+        })
+      }
+    } else {
+      // Regular Kafka consumers need READ permission on each partition they are fetching.
+      fetchContext.foreachPartition((part, data) => {
+        if (!authorize(request.session, Read, new Resource(Topic, part.topic)))
+          erroneous += part -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED,
+            FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
+            FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
+        else if (!metadataCache.contains(part.topic))
+          erroneous += part -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION,
             FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET,
             FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)
         else
-          authorizedRequestInfo += (topicPartition -> partitionData)
-      }
+          interesting += (part -> data)
+      })
+    }
 
     def convertedPartitionData(tp: TopicPartition, data: FetchResponse.PartitionData) = {
-
       // Down-conversion of the fetched records is needed when the stored magic version is
       // greater than that supported by the client (as indicated by the fetch request version). If the
       // configured magic version for the topic is less than or equal to that supported by the version of the
@@ -529,7 +547,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
         downConvertMagic.map { magic =>
           trace(s"Down converting records from partition $tp to message format version $magic for fetch request from $clientId")
-          val converted = data.records.downConvert(magic, fetchRequest.fetchData.get(tp).fetchOffset, time)
+          val converted = data.records.downConvert(magic, fetchContext.getFetchOffset(tp).get, time)
           updateRecordsProcessingStats(request, tp, converted.recordsProcessingStats)
           new FetchResponse.PartitionData(data.error, data.highWatermark, FetchResponse.INVALID_LAST_STABLE_OFFSET,
             data.logStartOffset, data.abortedTransactions, converted.records)
@@ -540,34 +558,28 @@ class KafkaApis(val requestChannel: RequestChannel,
 
     // the callback for processing a fetch response, invoked before throttling
     def processResponseCallback(responsePartitionData: Seq[(TopicPartition, FetchPartitionData)]) {
-      val partitionData = {
-        responsePartitionData.map { case (tp, data) =>
-          val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull
-          val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
-          tp -> new FetchResponse.PartitionData(data.error, data.highWatermark, lastStableOffset,
-            data.logStartOffset, abortedTransactions, data.records)
-        }
-      }
-
-      val mergedPartitionData = partitionData ++ unauthorizedTopicResponseData ++ nonExistingTopicResponseData
-      val fetchedPartitionData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]()
-
-      mergedPartitionData.foreach { case (topicPartition, data) =>
-        if (data.error != Errors.NONE)
-          debug(s"Fetch request with correlation id ${request.header.correlationId} from client $clientId " +
-            s"on partition $topicPartition failed due to ${data.error.exceptionName}")
-
-        fetchedPartitionData.put(topicPartition, data)
+      val partitions = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+      responsePartitionData.foreach { case (tp, data) =>
+        val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull
+        val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
+        partitions.put(tp, new FetchResponse.PartitionData(data.error, data.highWatermark, lastStableOffset,
+          data.logStartOffset, abortedTransactions, data.records))
       }
+      erroneous.foreach { case (tp, data) => partitions.put(tp, data) }
+      val unconvertedFetchResponse = fetchContext.updateAndGenerateResponseData(partitions)
 
       // fetch response callback invoked after any throttling
       def fetchResponseCallback(bandwidthThrottleTimeMs: Int) {
         def createResponse(requestThrottleTimeMs: Int): FetchResponse = {
           val convertedData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
-          fetchedPartitionData.asScala.foreach { case (tp, partitionData) =>
+          unconvertedFetchResponse.responseData().asScala.foreach { case (tp, partitionData) =>
+            if (partitionData.error != Errors.NONE)
+              debug(s"Fetch request with correlation id ${request.header.correlationId} from client $clientId " +
+                s"on partition $tp failed due to ${partitionData.error.exceptionName}")
             convertedData.put(tp, convertedPartitionData(tp, partitionData))
           }
-          val response = new FetchResponse(convertedData, bandwidthThrottleTimeMs + requestThrottleTimeMs)
+          val response = new FetchResponse(unconvertedFetchResponse.error(), convertedData,
+            bandwidthThrottleTimeMs + requestThrottleTimeMs, unconvertedFetchResponse.sessionId())
           response.responseData.asScala.foreach { case (topicPartition, data) =>
             // record the bytes out metrics only when the response is being sent
             brokerTopicStats.updateBytesOut(topicPartition.topic, fetchRequest.isFromFollower, data.records.sizeInBytes)
@@ -575,6 +587,9 @@ class KafkaApis(val requestChannel: RequestChannel,
           response
         }
 
+        trace(s"Sending Fetch response with partitions.size=${unconvertedFetchResponse.responseData().size()}, " +
+          s"metadata=${unconvertedFetchResponse.sessionId()}")
+
         if (fetchRequest.isFromFollower)
           sendResponseExemptThrottle(request, createResponse(0))
         else
@@ -587,21 +602,20 @@ class KafkaApis(val requestChannel: RequestChannel,
 
       if (fetchRequest.isFromFollower) {
         // We've already evaluated against the quota and are good to go. Just need to record it now.
-        val responseSize = sizeOfThrottledPartitions(versionId, fetchRequest, mergedPartitionData, quotas.leader)
+        val responseSize = sizeOfThrottledPartitions(versionId, unconvertedFetchResponse, quotas.leader)
         quotas.leader.record(responseSize)
         fetchResponseCallback(bandwidthThrottleTimeMs = 0)
       } else {
         // Fetch size used to determine throttle time is calculated before any down conversions.
         // This may be slightly different from the actual response size. But since down conversions
         // result in data being loaded into memory, it is better to do this after throttling to avoid OOM.
-        val response = new FetchResponse(fetchedPartitionData, 0)
-        val responseStruct = response.toStruct(versionId)
+        val responseStruct = unconvertedFetchResponse.toStruct(versionId)
         quotas.fetch.maybeRecordAndThrottle(request.session.sanitizedUser, clientId, responseStruct.sizeOf,
           fetchResponseCallback)
       }
     }
 
-    if (authorizedRequestInfo.isEmpty)
+    if (interesting.isEmpty)
       processResponseCallback(Seq.empty)
     else {
       // call the replica manager to fetch messages from the local replica
@@ -611,23 +625,45 @@ class KafkaApis(val requestChannel: RequestChannel,
         fetchRequest.minBytes,
         fetchRequest.maxBytes,
         versionId <= 2,
-        authorizedRequestInfo,
+        interesting,
         replicationQuota(fetchRequest),
         processResponseCallback,
         fetchRequest.isolationLevel)
     }
   }
 
+  class SelectingIterator(val partitions: util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData],
+                          val quota: ReplicationQuotaManager)
+                          extends util.Iterator[util.Map.Entry[TopicPartition, FetchResponse.PartitionData]] {
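+    // Yields only the partitions in the response that are subject to the replication
+    // quota, so sizeOfThrottledPartitions can measure them without copying the map.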
+    val iter = partitions.entrySet().iterator()
+
+    var nextElement: util.Map.Entry[TopicPartition, FetchResponse.PartitionData] = null
+
+    override def hasNext: Boolean = {
+      while ((nextElement == null) && iter.hasNext()) {
+        val element = iter.next()
+        if (quota.isThrottled(element.getKey)) {
+          nextElement = element
+        }
+      }
+      nextElement != null
+    }
+
+    override def next(): util.Map.Entry[TopicPartition, FetchResponse.PartitionData] = {
+      if (!hasNext()) throw new NoSuchElementException()
+      val element = nextElement
+      nextElement = null
+      element
+    }
+
+    override def remove() = throw new UnsupportedOperationException()
+  }
+
   private def sizeOfThrottledPartitions(versionId: Short,
-                                        fetchRequest: FetchRequest,
-                                        mergedPartitionData: Seq[(TopicPartition, FetchResponse.PartitionData)],
+                                        unconvertedResponse: FetchResponse,
                                         quota: ReplicationQuotaManager): Int = {
-    val partitionData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
-    mergedPartitionData.foreach { case (tp, data) =>
-      if (quota.isThrottled(tp))
-        partitionData.put(tp, data)
-    }
-    FetchResponse.sizeOf(versionId, partitionData)
+    val iter = new SelectingIterator(unconvertedResponse.responseData(), quota)
+    FetchResponse.sizeOf(versionId, iter)
   }
 
   def replicationQuota(fetchRequest: FetchRequest): ReplicaQuota =
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index 64698f7..0b9bdaa 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -174,6 +174,9 @@ object Defaults {
   val TransactionsAbortTimedOutTransactionsCleanupIntervalMS = TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs
   val TransactionsRemoveExpiredTransactionsCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs
 
+  /** ********* Fetch Session Configuration **************/
+  val MaxIncrementalFetchSessionCacheSlots = 1000
+
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefault = ClientQuotaManagerConfig.QuotaBytesPerSecondDefault
   val ConsumerQuotaBytesPerSecondDefault = ClientQuotaManagerConfig.QuotaBytesPerSecondDefault
@@ -375,6 +378,9 @@ object KafkaConfig {
   val TransactionsAbortTimedOutTransactionCleanupIntervalMsProp = "transaction.abort.timed.out.transaction.cleanup.interval.ms"
   val TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp = "transaction.remove.expired.transaction.cleanup.interval.ms"
 
+  /** ********* Fetch Session Configuration **************/
+  val MaxIncrementalFetchSessionCacheSlots = "max.incremental.fetch.session.cache.slots"
+
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefaultProp = "quota.producer.default"
   val ConsumerQuotaBytesPerSecondDefaultProp = "quota.consumer.default"
@@ -652,6 +658,9 @@ object KafkaConfig {
   val TransactionsAbortTimedOutTransactionsIntervalMsDoc = "The interval at which to rollback transactions that have timed out"
   val TransactionsRemoveExpiredTransactionsIntervalMsDoc = "The interval at which to remove transactions that have expired due to <code>transactional.id.expiration.ms<code> passing"
 
+  /** ********* Fetch Session Configuration **************/
+  val MaxIncrementalFetchSessionCacheSlotsDoc = "The maximum number of incremental fetch sessions that we will maintain."
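+  // For example, setting max.incremental.fetch.session.cache.slots=2000 in
+  // server.properties raises the limit to 2000 sessions (an illustrative value;
+  // the default is 1000).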
+
   /** ********* Quota Configuration ***********/
   val ProducerQuotaBytesPerSecondDefaultDoc = "DEPRECATED: Used only when dynamic default quotas are not configured for <user>, <client-id> or <user, client-id> in Zookeeper. " +
   "Any producer distinguished by clientId will get throttled if it produces more bytes than this value per-second"
@@ -886,6 +895,9 @@ object KafkaConfig {
       .define(TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, INT, Defaults.TransactionsAbortTimedOutTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsAbortTimedOutTransactionsIntervalMsDoc)
       .define(TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp, INT, Defaults.TransactionsRemoveExpiredTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsRemoveExpiredTransactionsIntervalMsDoc)
 
+    /** ********* Fetch Session Configuration **************/
+      .define(MaxIncrementalFetchSessionCacheSlots, INT, Defaults.MaxIncrementalFetchSessionCacheSlots, atLeast(0), MEDIUM, MaxIncrementalFetchSessionCacheSlotsDoc)
+
       /** ********* Kafka Metrics Configuration ***********/
       .define(MetricNumSamplesProp, INT, Defaults.MetricNumSamples, atLeast(1), LOW, MetricNumSamplesDoc)
       .define(MetricSampleWindowMsProp, LONG, Defaults.MetricSampleWindowMs, atLeast(1), LOW, MetricSampleWindowMsDoc)
@@ -1196,6 +1208,9 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean, dynamicConfigO
   /** ********* Transaction Configuration **************/
   val transactionIdExpirationMs = getInt(KafkaConfig.TransactionalIdExpirationMsProp)
 
+  /** ********* Fetch Session Configuration **************/
+  val maxIncrementalFetchSessionCacheSlots = getInt(KafkaConfig.MaxIncrementalFetchSessionCacheSlots)
+
   val deleteTopicEnable = getBoolean(KafkaConfig.DeleteTopicEnableProp)
   def compressionType = getString(KafkaConfig.CompressionTypeProp)
 
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index 0212181..d7ca656 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -90,6 +90,7 @@ object KafkaServer {
       .timeWindow(kafkaConfig.metricSampleWindowMs, TimeUnit.MILLISECONDS)
   }
 
+  val MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS: Long = 120000
 }
 
 /**
@@ -282,10 +283,14 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP
           authZ
         }
 
+        val fetchManager = new FetchManager(Time.SYSTEM,
+          new FetchSessionCache(config.maxIncrementalFetchSessionCacheSlots,
+            KafkaServer.MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS))
+
         /* start processing requests */
         apis = new KafkaApis(socketServer.requestChannel, replicaManager, adminManager, groupCoordinator, transactionCoordinator,
           kafkaController, zkClient, config.brokerId, config, metadataCache, metrics, authorizer, quotaManagers,
-          brokerTopicStats, clusterId, time, tokenManager)
+          fetchManager, brokerTopicStats, clusterId, time, tokenManager)
 
         requestHandlerPool = new KafkaRequestHandlerPool(config.brokerId, socketServer.requestChannel, apis, time,
           config.numIoThreads)
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index da94c4a..8344d5b 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -26,6 +26,7 @@ import kafka.log.LogConfig
 import kafka.server.ReplicaFetcherThread._
 import kafka.server.epoch.LeaderEpochCache
 import kafka.zk.AdminZkClient
+import org.apache.kafka.clients.FetchSessionHandler
 import org.apache.kafka.common.requests.EpochEndOffset._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.KafkaStorageException
@@ -35,6 +36,7 @@ import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests.{EpochEndOffset, FetchResponse, ListOffsetRequest, ListOffsetResponse, OffsetsForLeaderEpochRequest, OffsetsForLeaderEpochResponse, FetchRequest => JFetchRequest}
 import org.apache.kafka.common.utils.{LogContext, Time}
+
 import scala.collection.JavaConverters._
 import scala.collection.{Map, mutable}
 
@@ -65,17 +67,20 @@ class ReplicaFetcherThread(name: String,
     new ReplicaFetcherBlockingSend(sourceBroker, brokerConfig, metrics, time, fetcherId,
       s"broker-$replicaId-fetcher-$fetcherId", logContext))
   private val fetchRequestVersion: Short =
-    if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV1) 5
+    if (brokerConfig.interBrokerProtocolVersion >= KAFKA_1_1_IV0) 7
+    else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV1) 5
     else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV0) 4
     else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_1_IV1) 3
     else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_0_IV0) 2
     else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_9_0) 1
     else 0
+  private val fetchMetadataSupported = brokerConfig.interBrokerProtocolVersion >= KAFKA_1_1_IV0
   private val maxWait = brokerConfig.replicaFetchWaitMaxMs
   private val minBytes = brokerConfig.replicaFetchMinBytes
   private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes
   private val fetchSize = brokerConfig.replicaFetchMaxBytes
   private val shouldSendLeaderEpochRequest: Boolean = brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV2
+  private val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id)
 
   private def epochCacheOpt(tp: TopicPartition): Option[LeaderEpochCache] =  replicaMgr.getReplica(tp).map(_.epochs.get)
 
@@ -211,10 +216,20 @@ class ReplicaFetcherThread(name: String,
   }
 
   protected def fetch(fetchRequest: FetchRequest): Seq[(TopicPartition, PartitionData)] = {
-    val clientResponse = leaderEndpoint.sendRequest(fetchRequest.underlying)
-    val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse]
-    fetchResponse.responseData.asScala.toSeq.map { case (key, value) =>
-      key -> new PartitionData(value)
+    try {
+      val clientResponse = leaderEndpoint.sendRequest(fetchRequest.underlying)
+      val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse]
+      if (!fetchSessionHandler.handleResponse(fetchResponse)) {
+        Nil
+      } else {
+        fetchResponse.responseData.asScala.toSeq.map { case (key, value) =>
+          key -> new PartitionData(value)
+        }
+      }
+    } catch {
+      case t: Throwable =>
+        fetchSessionHandler.handleError(t)
+        throw t
     }
   }
 
@@ -240,15 +255,16 @@ class ReplicaFetcherThread(name: String,
   }
 
   override def buildFetchRequest(partitionMap: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[FetchRequest] = {
-    val requestMap = new util.LinkedHashMap[TopicPartition, JFetchRequest.PartitionData]
     val partitionsWithError = mutable.Set[TopicPartition]()
 
+    val builder = fetchSessionHandler.newBuilder()
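+    // The session handler tracks which partitions the leader's fetch session already
+    // contains; builder.build() below computes the partitions to send and to forget.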
     partitionMap.foreach { case (topicPartition, partitionFetchState) =>
       // We will not include a replica in the fetch request if it should be throttled.
       if (partitionFetchState.isReadyForFetch && !shouldFollowerThrottle(quota, topicPartition)) {
         try {
           val logStartOffset = replicaMgr.getReplicaOrException(topicPartition).logStartOffset
-          requestMap.put(topicPartition, new JFetchRequest.PartitionData(partitionFetchState.fetchOffset, logStartOffset, fetchSize))
+          builder.add(topicPartition, new JFetchRequest.PartitionData(
+            partitionFetchState.fetchOffset, logStartOffset, fetchSize))
         } catch {
           case _: KafkaStorageException =>
             // The replica has already been marked offline due to log directory failure and the original failure should have already been logged.
@@ -258,9 +274,15 @@ class ReplicaFetcherThread(name: String,
       }
     }
 
-    val requestBuilder = JFetchRequest.Builder.forReplica(fetchRequestVersion, replicaId, maxWait, minBytes, requestMap)
-      .setMaxBytes(maxBytes)
-    ResultWithPartitions(new FetchRequest(requestBuilder), partitionsWithError)
+    val fetchData = builder.build()
+    val requestBuilder = JFetchRequest.Builder.
+      forReplica(fetchRequestVersion, replicaId, maxWait, minBytes, fetchData.toSend())
+        .setMaxBytes(maxBytes)
+        .toForget(fetchData.toForget)
+    if (fetchMetadataSupported) {
+      requestBuilder.metadata(fetchData.metadata())
+    }
+    ResultWithPartitions(new FetchRequest(fetchData.sessionPartitions(), requestBuilder), partitionsWithError)
   }
 
   /**
@@ -365,10 +387,12 @@ class ReplicaFetcherThread(name: String,
 
 object ReplicaFetcherThread {
 
-  private[server] class FetchRequest(val underlying: JFetchRequest.Builder) extends AbstractFetcherThread.FetchRequest {
-    def isEmpty: Boolean = underlying.fetchData.isEmpty
+  private[server] class FetchRequest(val sessionParts: util.Map[TopicPartition, JFetchRequest.PartitionData],
+                                     val underlying: JFetchRequest.Builder)
+      extends AbstractFetcherThread.FetchRequest {
     def offset(topicPartition: TopicPartition): Long =
-      underlying.fetchData.asScala(topicPartition).fetchOffset
+      sessionParts.get(topicPartition).fetchOffset
+    override def isEmpty = sessionParts.isEmpty && underlying.toForget().isEmpty
     override def toString = underlying.toString
   }
 
diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala
index 9090fda..f2b3552 100644
--- a/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala
+++ b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala
@@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.{Record, RecordBatch}
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse}
+import org.apache.kafka.common.requests.{FetchMetadata => JFetchMetadata}
 import org.apache.kafka.common.serialization.{ByteArraySerializer, StringSerializer}
 import org.junit.Assert._
 import org.junit.Test
@@ -294,6 +295,60 @@ class FetchRequestTest extends BaseRequestTest {
       expectedMagic = RecordBatch.MAGIC_VALUE_V2)
   }
 
+  /**
+    * Test that when an incremental fetch session contains partitions with an error,
+    * those partitions are returned in every incremental fetch response until the
+    * error has been resolved.
+    */
+  @Test
+  def testCreateIncrementalFetchWithPartitionsInError(): Unit = {
+    def createFetchRequest(topicPartitions: Seq[TopicPartition],
+                           metadata: JFetchMetadata,
+                           toForget: Seq[TopicPartition]): FetchRequest =
+      FetchRequest.Builder.forConsumer(Int.MaxValue, 0,
+        createPartitionMap(Integer.MAX_VALUE, topicPartitions, Map.empty))
+          .toForget(toForget.asJava)
+          .metadata(metadata)
+          .build()
+    val foo0 = new TopicPartition("foo", 0)
+    val foo1 = new TopicPartition("foo", 1)
+    createTopic("foo", Map(0 -> List(0, 1), 1 -> List(0, 2)))
+    val bar0 = new TopicPartition("bar", 0)
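+    // The "bar" topic has not been created yet, so bar-0 should come back in error.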
+    val req1 = createFetchRequest(List(foo0, foo1, bar0), JFetchMetadata.INITIAL, Nil)
+    val resp1 = sendFetchRequest(0, req1)
+    assertEquals(Errors.NONE, resp1.error())
+    assertTrue("Expected the broker to create a new incremental fetch session", resp1.sessionId() > 0)
+    debug(s"Test created an incremental fetch session ${resp1.sessionId}")
+    assertTrue(resp1.responseData().containsKey(foo0))
+    assertTrue(resp1.responseData().containsKey(foo1))
+    assertTrue(resp1.responseData().containsKey(bar0))
+    assertEquals(Errors.NONE, resp1.responseData().get(foo0).error)
+    assertEquals(Errors.NONE, resp1.responseData().get(foo1).error)
+    assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, resp1.responseData().get(bar0).error)
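+    // An empty incremental fetch should still return the partition that remains in error.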
+    val req2 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 1), Nil)
+    val resp2 = sendFetchRequest(0, req2)
+    assertEquals(Errors.NONE, resp2.error())
+    assertEquals("Expected the broker to continue the incremental fetch session",
+      resp1.sessionId(), resp2.sessionId())
+    assertFalse(resp2.responseData().containsKey(foo0))
+    assertFalse(resp2.responseData().containsKey(foo1))
+    assertTrue(resp2.responseData().containsKey(bar0))
+    assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, resp2.responseData().get(bar0).error)
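+    // Once the topic exists, the next incremental fetch should return bar-0 without error.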
+    createTopic("bar", Map(0 -> List(0, 1)))
+    val req3 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 2), Nil)
+    val resp3 = sendFetchRequest(0, req3)
+    assertEquals(Errors.NONE, resp3.error())
+    assertFalse(resp3.responseData().containsKey(foo0))
+    assertFalse(resp3.responseData().containsKey(foo1))
+    assertTrue(resp3.responseData().containsKey(bar0))
+    assertEquals(Errors.NONE, resp3.responseData().get(bar0).error)
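+    // Once bar-0 has been returned without error, later incremental responses omit it,
+    // since there is no longer any change to report.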
+    val req4 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 3), Nil)
+    val resp4 = sendFetchRequest(0, req4)
+    assertEquals(Errors.NONE, resp4.error())
+    assertFalse(resp4.responseData().containsKey(foo0))
+    assertFalse(resp4.responseData().containsKey(foo1))
+    assertFalse(resp4.responseData().containsKey(bar0))
+  }
+
   private def records(partitionData: FetchResponse.PartitionData): Seq[Record] = {
     partitionData.records.records.asScala.toIndexedSeq
   }
diff --git a/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
new file mode 100755
index 0000000..3320b63
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
@@ -0,0 +1,312 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+package kafka.server
+
+import java.util
+import java.util.Collections
+
+import kafka.utils.MockTime
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID}
+import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata}
+import org.junit.{Rule, Test}
+import org.junit.Assert._
+import org.junit.rules.Timeout
+
+class FetchSessionTest {
+  @Rule
+  def globalTimeout = Timeout.millis(120000)
+
+  @Test
+  def testNewSessionId(): Unit = {
+    val cache = new FetchSessionCache(3, 100)
+    for (i <- 0 to 10000) {
+      val id = cache.newSessionId()
+      assertTrue(id > 0)
+    }
+  }
+
+  def assertCacheContains(cache: FetchSessionCache, sessionIds: Int*) = {
+    var i = 0
+    for (sessionId <- sessionIds) {
+      i = i + 1
+      assertTrue("Missing session " + i + " out of " + sessionIds.size + "(" + sessionId + ")",
+        cache.get(sessionId).isDefined)
+    }
+    assertEquals(sessionIds.size, cache.size())
+  }
+
+  private def dummyCreate(size: Int)() = {
+    val cacheMap = new FetchSession.CACHE_MAP(size)
+    for (i <- 0 until size) {
+      cacheMap.add(new CachedPartition("test", i))
+    }
+    cacheMap
+  }
+
+  @Test
+  def testSessionCache(): Unit = {
+    val cache = new FetchSessionCache(3, 100)
+    assertEquals(0, cache.size())
+    val id1 = cache.maybeCreateSession(0, false, 10, dummyCreate(10))
+    val id2 = cache.maybeCreateSession(10, false, 20, dummyCreate(20))
+    val id3 = cache.maybeCreateSession(20, false, 30, dummyCreate(30))
+    assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(30, false, 40, dummyCreate(40)))
+    assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(40, false, 5, dummyCreate(5)))
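+    // The cache is full and no entry is stale or old enough to be evictable by another
+    // unprivileged session, so even a larger proposed session cannot displace one.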
+    assertCacheContains(cache, id1, id2, id3)
+    cache.touch(cache.get(id1).get, 200)
+    val id4 = cache.maybeCreateSession(210, false, 11, dummyCreate(11))
+    assertCacheContains(cache, id1, id3, id4)
+    cache.touch(cache.get(id1).get, 400)
+    cache.touch(cache.get(id3).get, 390)
+    cache.touch(cache.get(id4).get, 400)
+    val id5 = cache.maybeCreateSession(410, false, 50, dummyCreate(50))
+    assertCacheContains(cache, id3, id4, id5)
+    assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(410, false, 5, dummyCreate(5)))
+    val id6 = cache.maybeCreateSession(410, true, 5, dummyCreate(5))
+    assertCacheContains(cache, id3, id5, id6)
+  }
+
+  @Test
+  def testResizeCachedSessions(): Unit = {
+    val cache = new FetchSessionCache(2, 100)
+    assertEquals(0, cache.totalPartitions())
+    assertEquals(0, cache.size())
+    assertEquals(0, cache.evictionsMeter.count())
+    val id1 = cache.maybeCreateSession(0, false, 2, dummyCreate(2))
+    assertTrue(id1 > 0)
+    assertCacheContains(cache, id1)
+    val session1 = cache.get(id1).get
+    assertEquals(2, session1.size())
+    assertEquals(2, cache.totalPartitions())
+    assertEquals(1, cache.size())
+    assertEquals(0, cache.evictionsMeter.count())
+    val id2 = cache.maybeCreateSession(0, false, 4, dummyCreate(4))
+    val session2 = cache.get(id2).get
+    assertTrue(id2 > 0)
+    assertCacheContains(cache, id1, id2)
+    assertEquals(6, cache.totalPartitions())
+    assertEquals(2, cache.size())
+    assertEquals(0, cache.evictionsMeter.count())
+    cache.touch(session1, 200)
+    cache.touch(session2, 200)
+    val id3 = cache.maybeCreateSession(200, false, 5, dummyCreate(5))
+    assertTrue(id3 > 0)
+    assertCacheContains(cache, id2, id3)
+    assertEquals(9, cache.totalPartitions())
+    assertEquals(2, cache.size())
+    assertEquals(1, cache.evictionsMeter.count())
+    cache.remove(id3)
+    assertCacheContains(cache, id2)
+    assertEquals(1, cache.size())
+    assertEquals(1, cache.evictionsMeter.count())
+    assertEquals(4, cache.totalPartitions())
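+    // Removing a partition through the session's iterator shrinks size() immediately,
+    // but cachedSize and the cache's partition count only update on the next touch().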
+    val iter = session2.partitionMap.iterator()
+    iter.next()
+    iter.remove()
+    assertEquals(3, session2.size())
+    assertEquals(4, session2.cachedSize)
+    cache.touch(session2, session2.lastUsedMs)
+    assertEquals(3, cache.totalPartitions())
+  }
+
+  val EMPTY_PART_LIST = Collections.unmodifiableList(new util.ArrayList[TopicPartition]())
+
+  @Test
+  def testFetchRequests(): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+
+    // Verify that SESSIONLESS requests get a SessionlessFetchContext
+    val context = fetchManager.newContext(JFetchMetadata.LEGACY,
+        new util.HashMap[TopicPartition, FetchRequest.PartitionData](), EMPTY_PART_LIST, true)
+    assertEquals(classOf[SessionlessFetchContext], context.getClass)
+
+    // Create a new fetch session with a FULL fetch request
+    val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    reqData2.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100))
+    reqData2.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100))
+    val context2 = fetchManager.newContext(JFetchMetadata.INITIAL, reqData2, EMPTY_PART_LIST, false)
+    assertEquals(classOf[FullFetchContext], context2.getClass)
+    val reqData2Iter = reqData2.entrySet().iterator()
+    context2.foreachPartition((topicPart, data) => {
+      val entry = reqData2Iter.next()
+      assertEquals(entry.getKey, topicPart)
+      assertEquals(entry.getValue, data)
+    })
+    assertEquals(0, context2.getFetchOffset(new TopicPartition("foo", 0)).get)
+    assertEquals(10, context2.getFetchOffset(new TopicPartition("foo", 1)).get)
+    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+    respData2.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData(
+      Errors.NONE, 100, 100, 100, null, null))
+    respData2.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData(
+      Errors.NONE, 10, 10, 10, null, null))
+    val resp2 = context2.updateAndGenerateResponseData(respData2)
+    assertEquals(Errors.NONE, resp2.error())
+    assertTrue(resp2.sessionId() != INVALID_SESSION_ID)
+    assertEquals(respData2, resp2.responseData())
+
+    // Test trying to create a new session with an invalid epoch
+    val context3 = fetchManager.newContext(
+      new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false)
+    assertEquals(classOf[SessionErrorContext], context3.getClass)
+    assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH,
+      context3.updateAndGenerateResponseData(respData2).error())
+
+    // Test trying to create a new session with a non-existent session id
+    val context4 = fetchManager.newContext(
+      new JFetchMetadata(resp2.sessionId() + 1, 1), reqData2, EMPTY_PART_LIST, false)
+    assertEquals(classOf[SessionErrorContext], context4.getClass)
+    assertEquals(Errors.FETCH_SESSION_ID_NOT_FOUND,
+      context4.updateAndGenerateResponseData(respData2).error())
+
+    // Continue the first fetch session we created.
+    val reqData5 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    val context5 = fetchManager.newContext(
+      new JFetchMetadata(resp2.sessionId(), 1), reqData5, EMPTY_PART_LIST, false)
+    assertEquals(classOf[IncrementalFetchContext], context5.getClass)
+    val reqData5Iter = reqData2.entrySet().iterator()
+    context5.foreachPartition((topicPart, data) => {
+      val entry = reqData5Iter.next()
+      assertEquals(entry.getKey, topicPart)
+      assertEquals(entry.getValue, data)
+    })
+    assertEquals(10, context5.getFetchOffset(new TopicPartition("foo", 1)).get)
+    val resp5 = context5.updateAndGenerateResponseData(respData2)
+    assertEquals(Errors.NONE, resp5.error())
+    assertEquals(resp2.sessionId(), resp5.sessionId())
+    assertEquals(0, resp5.responseData().size())
+
+    // Test setting an invalid fetch session epoch.
+    val context6 = fetchManager.newContext(
+      new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false)
+    assertEquals(classOf[SessionErrorContext], context6.getClass)
+    assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH,
+      context6.updateAndGenerateResponseData(respData2).error())
+
+    // Close the incremental fetch session.
+    val prevSessionId = resp5.sessionId()
+    var nextSessionId = prevSessionId
+    do {
+      val reqData7 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+      reqData7.put(new TopicPartition("bar", 0), new FetchRequest.PartitionData(0, 0, 100))
+      reqData7.put(new TopicPartition("bar", 1), new FetchRequest.PartitionData(10, 0, 100))
+      val context7 = fetchManager.newContext(
+        new JFetchMetadata(prevSessionId, FINAL_EPOCH), reqData7, EMPTY_PART_LIST, false)
+      assertEquals(classOf[SessionlessFetchContext], context7.getClass)
+      assertEquals(0, cache.size())
+      val respData7 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+      respData7.put(new TopicPartition("bar", 0),
+        new FetchResponse.PartitionData(Errors.NONE, 100, 100, 100, null, null))
+      respData7.put(new TopicPartition("bar", 1),
+        new FetchResponse.PartitionData(Errors.NONE, 100, 100, 100, null, null))
+      val resp7 = context7.updateAndGenerateResponseData(respData7)
+      assertEquals(Errors.NONE, resp7.error())
+      nextSessionId = resp7.sessionId()
+    } while (nextSessionId == prevSessionId)
+  }
+
+  @Test
+  def testIncrementalFetchSession(): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+
+    // Create a new fetch session with foo-0 and foo-1
+    val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100))
+    reqData1.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100))
+    val context1 = fetchManager.newContext(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
+    assertEquals(classOf[FullFetchContext], context1.getClass)
+    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+    respData1.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData(
+      Errors.NONE, 100, 100, 100, null, null))
+    respData1.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData(
+      Errors.NONE, 10, 10, 10, null, null))
+    val resp1 = context1.updateAndGenerateResponseData(respData1)
+    assertEquals(Errors.NONE, resp1.error())
+    assertTrue(resp1.sessionId() != INVALID_SESSION_ID)
+    assertEquals(2, resp1.responseData().size())
+
+    // Create an incremental fetch request that removes foo-0 and adds bar-0
+    val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    reqData2.put(new TopicPartition("bar", 0), new FetchRequest.PartitionData(15, 0, 0))
+    val removed2 = new util.ArrayList[TopicPartition]
+    removed2.add(new TopicPartition("foo", 0))
+    val context2 = fetchManager.newContext(
+      new JFetchMetadata(resp1.sessionId(), 1), reqData2, removed2, false)
+    assertEquals(classOf[IncrementalFetchContext], context2.getClass)
+    val parts2 = Set(new TopicPartition("foo", 1), new TopicPartition("bar", 0))
+    val reqData2Iter = parts2.iterator
+    context2.foreachPartition((topicPart, data) => {
+      assertEquals(reqData2Iter.next(), topicPart)
+    })
+    assertEquals(None, context2.getFetchOffset(new TopicPartition("foo", 0)))
+    assertEquals(10, context2.getFetchOffset(new TopicPartition("foo", 1)).get)
+    assertEquals(15, context2.getFetchOffset(new TopicPartition("bar", 0)).get)
+    assertEquals(None, context2.getFetchOffset(new TopicPartition("bar", 2)))
+    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+    respData2.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData(
+      Errors.NONE, 10, 10, 10, null, null))
+    respData2.put(new TopicPartition("bar", 0), new FetchResponse.PartitionData(
+      Errors.NONE, 10, 10, 10, null, null))
+    val resp2 = context2.updateAndGenerateResponseData(respData2)
+    assertEquals(Errors.NONE, resp2.error())
+    assertEquals(1, resp2.responseData().size())
+    assertTrue(resp2.sessionId() > 0)
+  }
+
+  @Test
+  def testZeroSizeFetchSession(): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+
+    // Create a new fetch session with foo-0 and foo-1
+    val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100))
+    reqData1.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100))
+    val context1 = fetchManager.newContext(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
+    assertEquals(classOf[FullFetchContext], context1.getClass)
+    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+    respData1.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData(
+      Errors.NONE, 100, 100, 100, null, null))
+    respData1.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData(
+      Errors.NONE, 10, 10, 10, null, null))
+    val resp1 = context1.updateAndGenerateResponseData(respData1)
+    assertEquals(Errors.NONE, resp1.error())
+    assertTrue(resp1.sessionId() != INVALID_SESSION_ID)
+    assertEquals(2, resp1.responseData().size())
+
+    // Create an incremental fetch request that removes foo-0 and foo-1
+    // Verify that the previous fetch session was closed.
+    val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    val removed2 = new util.ArrayList[TopicPartition]
+    removed2.add(new TopicPartition("foo", 0))
+    removed2.add(new TopicPartition("foo", 1))
+    val context2 = fetchManager.newContext(
+      new JFetchMetadata(resp1.sessionId(), 1), reqData2, removed2, false)
+    assertEquals(classOf[SessionlessFetchContext], context2.getClass)
+    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]
+    val resp2 = context2.updateAndGenerateResponseData(respData2)
+    assertEquals(INVALID_SESSION_ID, resp2.sessionId())
+    assertTrue(resp2.responseData().isEmpty)
+    assertEquals(0, cache.size())
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 8e907d9..5de978c 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -69,6 +69,7 @@ class KafkaApisTest {
   private val clientRequestQuotaManager = EasyMock.createNiceMock(classOf[ClientRequestQuotaManager])
   private val replicaQuotaManager = EasyMock.createNiceMock(classOf[ReplicationQuotaManager])
   private val quotas = QuotaManagers(clientQuotaManager, clientQuotaManager, clientRequestQuotaManager, replicaQuotaManager, replicaQuotaManager, replicaQuotaManager)
+  private val fetchManager = EasyMock.createNiceMock(classOf[FetchManager])
   private val brokerTopicStats = new BrokerTopicStats
   private val clusterId = "clusterId"
   private val time = new MockTime
@@ -96,6 +97,7 @@ class KafkaApisTest {
       metrics,
       authorizer,
       quotas,
+      fetchManager,
       brokerTopicStats,
       clusterId,
       time,
diff --git a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
index 0692afb..1f5bec1 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
@@ -20,10 +20,10 @@ import kafka.cluster.BrokerEndPoint
 import kafka.server.BlockingSend
 import org.apache.kafka.clients.{ClientRequest, ClientResponse, MockClient}
 import org.apache.kafka.common.{Node, TopicPartition}
-import org.apache.kafka.common.protocol.ApiKeys
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.AbstractRequest.Builder
 import org.apache.kafka.common.requests.FetchResponse.PartitionData
-import org.apache.kafka.common.requests.{AbstractRequest, EpochEndOffset, FetchResponse, OffsetsForLeaderEpochResponse}
+import org.apache.kafka.common.requests.{AbstractRequest, EpochEndOffset, FetchResponse, OffsetsForLeaderEpochResponse, FetchMetadata => JFetchMetadata}
 import org.apache.kafka.common.utils.{SystemTime, Time}
 
 /**
@@ -54,7 +54,8 @@ class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, Epoc
 
       case ApiKeys.FETCH =>
         fetchCount += 1
-        new FetchResponse(new java.util.LinkedHashMap[TopicPartition, PartitionData], 0)
+        new FetchResponse(Errors.NONE, new java.util.LinkedHashMap[TopicPartition, PartitionData], 0,
+          JFetchMetadata.INVALID_SESSION_ID)
 
       case _ =>
         throw new UnsupportedOperationException

-- 
To stop receiving notification emails like this one, please contact
junrao@apache.org.