Posted to commits@kafka.apache.org by ab...@apache.org on 2022/12/06 10:02:03 UTC

[kafka] branch trunk updated: KAFKA-13602: Adding ability to multicast records (#12803)

This is an automated email from the ASF dual-hosted git repository.

ableegoldman pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 77e294e7fca KAFKA-13602: Adding ability to multicast records (#12803)
77e294e7fca is described below

commit 77e294e7fca31e4e384930aa0c26431cfcc13410
Author: vamossagar12 <sa...@gmail.com>
AuthorDate: Tue Dec 6 15:31:38 2022 +0530

    KAFKA-13602: Adding ability to multicast records (#12803)
    
    This PR implements KIP-837, which enhances StreamPartitioner so that a record can be multicast to multiple partitions or dropped entirely (a minimal example partitioner is sketched after the file summary below).
    
    Reviewers: Anna Sophie Blee-Goldman <ab...@apache.org>, YEONCHEOL JANG
---
 .../org/apache/kafka/streams/KeyQueryMetadata.java |  24 +-
 .../streams/kstream/internals/KTableImpl.java      |  30 +-
 .../internals/WindowedStreamPartitioner.java       |   1 +
 .../kafka/streams/processor/StreamPartitioner.java |  23 ++
 .../internals/DefaultStreamPartitioner.java        |   1 +
 .../processor/internals/RecordCollectorImpl.java   |  24 +-
 .../processor/internals/StreamsMetadataState.java  |  39 ++-
 .../KStreamRepartitionIntegrationTest.java         | 105 +++++-
 ...yInnerJoinCustomPartitionerIntegrationTest.java |  90 +++++
 .../integration/StoreQueryIntegrationTest.java     |  53 +++
 .../integration/utils/IntegrationTestUtils.java    |   1 +
 .../kstream/internals/KStreamRepartitionTest.java  |  10 +-
 .../internals/WindowedStreamPartitionerTest.java   |   1 +
 .../processor/internals/ProcessorTopologyTest.java |  41 +++
 .../processor/internals/RecordCollectorTest.java   | 390 +++++++++++++++++++++
 .../internals/StreamsMetadataStateTest.java        |  42 ++-
 16 files changed, 837 insertions(+), 38 deletions(-)
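
KIP-837 extends StreamPartitioner with a default partitions() method returning an
Optional<Set<Integer>>: an empty Optional falls back to the default partitioning logic,
an Optional of an empty set drops the record, and a non-empty set multicasts the record
to every listed partition. As a minimal sketch (mirroring the BroadcastingPartitioner
used in the integration tests of this commit; the class name and key/value types are
illustrative), a partitioner that broadcasts every record to all partitions of the
target topic could look like:

    import java.util.Optional;
    import java.util.Set;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;

    import org.apache.kafka.streams.processor.StreamPartitioner;

    public class BroadcastingPartitioner implements StreamPartitioner<Integer, String> {

        @Override
        @Deprecated
        public Integer partition(final String topic, final Integer key, final String value, final int numPartitions) {
            // Unused once partitions() is overridden; kept only to satisfy the deprecated contract.
            return null;
        }

        @Override
        public Optional<Set<Integer>> partitions(final String topic, final Integer key, final String value, final int numPartitions) {
            // Send the record to every partition of the topic.
            return Optional.of(IntStream.range(0, numPartitions).boxed().collect(Collectors.toSet()));
        }
    }

Such a partitioner can be plugged in wherever a StreamPartitioner is accepted, for
example via Repartitioned#withStreamPartitioner, as the repartition integration test
below does.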

diff --git a/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java b/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java
index 9ca495214d6..6461ee7423f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java
@@ -43,10 +43,20 @@ public class KeyQueryMetadata {
 
     private final int partition;
 
+    private final Set<Integer> partitions;
+
     public KeyQueryMetadata(final HostInfo activeHost, final Set<HostInfo> standbyHosts, final int partition) {
         this.activeHost = activeHost;
         this.standbyHosts = standbyHosts;
         this.partition = partition;
+        this.partitions = Collections.singleton(partition);
+    }
+
+    public KeyQueryMetadata(final HostInfo activeHost, final Set<HostInfo> standbyHosts, final Set<Integer> partitions) {
+        this.activeHost = activeHost;
+        this.standbyHosts = standbyHosts;
+        this.partition = partitions.size() == 1 ? partitions.iterator().next() : -1;
+        this.partitions = partitions;
     }
 
     /**
@@ -109,6 +119,16 @@ public class KeyQueryMetadata {
         return partition;
     }
 
+    /**
+     * Get the store partitions corresponding to the key.
+     * A key can be present on multiple partitions if it has been
+     * multicast using the StreamPartitioner#partitions method.
+     * @return the set of store partition numbers
+     */
+    public Set<Integer> partitions() {
+        return partitions;
+    }
+
     @Override
     public boolean equals(final Object obj) {
         if (!(obj instanceof KeyQueryMetadata)) {
@@ -117,7 +137,8 @@ public class KeyQueryMetadata {
         final KeyQueryMetadata keyQueryMetadata = (KeyQueryMetadata) obj;
         return Objects.equals(keyQueryMetadata.activeHost, activeHost)
             && Objects.equals(keyQueryMetadata.standbyHosts, standbyHosts)
-            && Objects.equals(keyQueryMetadata.partition, partition);
+            && (Objects.equals(keyQueryMetadata.partition, partition)
+                || Objects.equals(keyQueryMetadata.partitions, partitions));
     }
 
     @Override
@@ -126,6 +147,7 @@ public class KeyQueryMetadata {
                 "activeHost=" + activeHost +
                 ", standbyHosts=" + standbyHosts +
                 ", partition=" + partition +
+                ", partitions=" + partitions +
                 '}';
     }
 
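The new KeyQueryMetadata#partitions() accessor becomes relevant when key metadata is
queried through a partitioner that may multicast: instead of a single partition, the
caller can see every store partition that may hold the key. A sketch (assuming a running
KafkaStreams instance and a store name; the helper class and method names are
illustrative, and the partitioner is the broadcasting sketch shown above):

    import java.util.Set;

    import org.apache.kafka.streams.KafkaStreams;
    import org.apache.kafka.streams.KeyQueryMetadata;

    final class KeyMetadataLookup {
        // Returns every store partition that may hold 'key' under the given partitioner.
        static Set<Integer> partitionsForKey(final KafkaStreams streams, final String storeName, final Integer key) {
            final KeyQueryMetadata metadata =
                streams.queryMetadataForKey(storeName, key, new BroadcastingPartitioner());
            // partition() falls back to -1 when more than one partition applies (see the
            // constructor above); partitions() reports the full candidate set.
            return metadata.partitions();
        }
    }
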
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java
index 2abe7f5386b..e34ac2f5841 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java
@@ -78,6 +78,7 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.Optional;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
@@ -1046,7 +1047,18 @@ public class KTableImpl<K, S, V> extends AbstractStream<K, V> implements KTable<
         return doJoinOnForeignKey(other, foreignKeyExtractor, joiner, TableJoined.with(null, null), materialized, true);
     }
 
-    @SuppressWarnings("unchecked")
+    private final Function<Optional<Set<Integer>>, Integer> getPartition = maybeMulticastPartitions -> {
+        if (!maybeMulticastPartitions.isPresent()) {
+            return null;
+        }
+        if (maybeMulticastPartitions.get().size() != 1) {
+            throw new IllegalArgumentException("The partitions returned by StreamPartitioner#partitions method when used for FK join should be a singleton set");
+        }
+        return maybeMulticastPartitions.get().iterator().next();
+    };
+
+
+    @SuppressWarnings({"unchecked", "deprecation"})
     private <VR, KO, VO> KTable<K, VR> doJoinOnForeignKey(final KTable<KO, VO> foreignKeyTable,
                                                           final Function<V, KO> foreignKeyExtractor,
                                                           final ValueJoiner<V, VO, VR> joiner,
@@ -1069,6 +1081,7 @@ public class KTableImpl<K, S, V> extends AbstractStream<K, V> implements KTable<
         enableSendingOldValues(true);
 
         final TableJoinedInternal<K, KO> tableJoinedInternal = new TableJoinedInternal<>(tableJoined);
+
         final NamedInternal renamed = new NamedInternal(tableJoinedInternal.name());
 
         final String subscriptionTopicName = renamed.suffixWithOrElseGet(
@@ -1118,12 +1131,10 @@ public class KTableImpl<K, S, V> extends AbstractStream<K, V> implements KTable<
         );
         builder.addGraphNode(graphNode, subscriptionNode);
 
-
         final StreamPartitioner<KO, SubscriptionWrapper<K>> subscriptionSinkPartitioner =
-            tableJoinedInternal.otherPartitioner() == null
-                ? null
-                : (topic, key, val, numPartitions) ->
-                    tableJoinedInternal.otherPartitioner().partition(topic, key, null, numPartitions);
+                tableJoinedInternal.otherPartitioner() == null
+                        ? null
+                        : (topic, key, val, numPartitions) -> getPartition.apply(tableJoinedInternal.otherPartitioner().partitions(topic, key, null, numPartitions));
 
         final StreamSinkNode<KO, SubscriptionWrapper<K>> subscriptionSink = new StreamSinkNode<>(
             renamed.suffixWithOrElseGet("-subscription-registration-sink", builder, SINK_NAME),
@@ -1196,10 +1207,9 @@ public class KTableImpl<K, S, V> extends AbstractStream<K, V> implements KTable<
         builder.internalTopologyBuilder.addInternalTopic(finalRepartitionTopicName, InternalTopicProperties.empty());
 
         final StreamPartitioner<K, SubscriptionResponseWrapper<VO>> foreignResponseSinkPartitioner =
-            tableJoinedInternal.partitioner() == null
-                ? (topic, key, subscriptionResponseWrapper, numPartitions) -> subscriptionResponseWrapper.getPrimaryPartition()
-                : (topic, key, val, numPartitions) ->
-                    tableJoinedInternal.partitioner().partition(topic, key, null, numPartitions);
+                tableJoinedInternal.partitioner() == null
+                        ? (topic, key, subscriptionResponseWrapper, numPartitions) -> subscriptionResponseWrapper.getPrimaryPartition()
+                        : (topic, key, val, numPartitions) -> getPartition.apply(tableJoinedInternal.partitioner().partitions(topic, key, null, numPartitions));
 
         final StreamSinkNode<K, SubscriptionResponseWrapper<VO>> foreignResponseSink =
             new StreamSinkNode<>(
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java
index d68a52b8d02..f1ea71981bf 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java
@@ -40,6 +40,7 @@ public class WindowedStreamPartitioner<K, V> implements StreamPartitioner<Window
      * @return an integer between 0 and {@code numPartitions-1}, or {@code null} if the default partitioning logic should be used
      */
     @Override
+    @Deprecated
     public Integer partition(final String topic, final Windowed<K> windowedKey, final V value, final int numPartitions) {
         // for windowed key, the key bytes should never be null
         final byte[] keyBytes = serializer.serializeBaseKey(topic, windowedKey);
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java b/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java
index 90ffa3a4a83..b4c2483db7d 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java
@@ -18,6 +18,10 @@ package org.apache.kafka.streams.processor;
 
 import org.apache.kafka.streams.Topology;
 
+import java.util.Collections;
+import java.util.Optional;
+import java.util.Set;
+
 /**
  * Determine how records are distributed among the partitions in a Kafka topic. If not specified, the underlying producer's
  * {@link org.apache.kafka.clients.producer.internals.DefaultPartitioner} will be used to determine the partition.
@@ -58,5 +62,24 @@ public interface StreamPartitioner<K, V> {
      * @param numPartitions the total number of partitions
      * @return an integer between 0 and {@code numPartitions-1}, or {@code null} if the default partitioning logic should be used
      */
+    @Deprecated
     Integer partition(String topic, K key, V value, int numPartitions);
+
+    /**
+     * Determine the partition numbers to which a record with the given key and value should be sent,
+     * for the given topic and current partition count.
+     * @param topic the topic name this record is sent to
+     * @param key the key of the record
+     * @param value the value of the record
+     * @param numPartitions the total number of partitions
+     * @return an Optional of a Set of integers between 0 and {@code numPartitions-1}.
+     * An empty Optional means the default partitioning logic should be used.
+     * An Optional of an empty set means the record will not be sent to any partition, i.e. it is dropped.
+     * An Optional of a non-empty set means the partitions to which the record should be sent.
+     */
+    default Optional<Set<Integer>> partitions(String topic, K key, V value, int numPartitions) {
+        final Integer partition = partition(topic, key, value, numPartitions);
+        return partition == null ? Optional.empty() : Optional.of(Collections.singleton(partition));
+    }
+
 }
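
The default implementation above keeps existing partitioners working: partition() is
called once and its result wrapped in a singleton set. For new implementations the other
two return shapes are also useful; a sketch of both (mirroring the dropping partitioners
used in the tests of this commit; class names are illustrative):

    import java.util.Collections;
    import java.util.Optional;
    import java.util.Set;

    import org.apache.kafka.streams.processor.StreamPartitioner;

    // Optional of an empty set: the record is sent to no partition, i.e. it is dropped.
    class DroppingPartitioner implements StreamPartitioner<String, String> {

        @Override
        @Deprecated
        public Integer partition(final String topic, final String key, final String value, final int numPartitions) {
            return null;
        }

        @Override
        public Optional<Set<Integer>> partitions(final String topic, final String key, final String value, final int numPartitions) {
            return Optional.of(Collections.emptySet());
        }
    }

    // Empty Optional: fall back to the default partitioning logic.
    class DefaultDelegatingPartitioner implements StreamPartitioner<String, String> {

        @Override
        @Deprecated
        public Integer partition(final String topic, final String key, final String value, final int numPartitions) {
            return null;
        }

        @Override
        public Optional<Set<Integer>> partitions(final String topic, final String key, final String value, final int numPartitions) {
            return Optional.empty();
        }
    }

RecordCollectorImpl (further below) records a dropped-record metric for the empty-set
case and delegates to the producer's partitioner for the empty-Optional case.
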
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java
index c7d909c65a3..d51b9791291 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java
@@ -29,6 +29,7 @@ public class DefaultStreamPartitioner<K, V> implements StreamPartitioner<K, V> {
     }
 
     @Override
+    @Deprecated
     public Integer partition(final String topic, final K key, final V value, final int numPartitions) {
         final byte[] keyBytes = keySerializer.serialize(topic, key);
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index e42dc4b5735..43c329896f6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -50,6 +50,8 @@ import org.apache.kafka.streams.processor.internals.metrics.TopicMetrics;
 
 import org.slf4j.Logger;
 
+import java.util.Set;
+import java.util.Optional;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -130,7 +132,6 @@ public class RecordCollectorImpl implements RecordCollector {
                             final String processorNodeId,
                             final InternalProcessorContext<Void, Void> context,
                             final StreamPartitioner<? super K, ? super V> partitioner) {
-        final Integer partition;
 
         if (partitioner != null) {
             final List<PartitionInfo> partitions;
@@ -150,16 +151,30 @@ public class RecordCollectorImpl implements RecordCollector {
                 );
             }
             if (partitions.size() > 0) {
-                partition = partitioner.partition(topic, key, value, partitions.size());
+                final Optional<Set<Integer>> maybeMulticastPartitions = partitioner.partitions(topic, key, value, partitions.size());
+                if (!maybeMulticastPartitions.isPresent()) {
+                    // An empty Optional indicates we should use the default partitioner
+                    send(topic, key, value, headers, null, timestamp, keySerializer, valueSerializer, processorNodeId, context);
+                } else {
+                    final Set<Integer> multicastPartitions = maybeMulticastPartitions.get();
+                    if (multicastPartitions.isEmpty()) {
+                        // If a record is not to be sent to any partition, mark it as a dropped record.
+                        log.debug("Not sending the record with key {}, value {} to any partition", key, value);
+                        droppedRecordsSensor.record();
+                    } else {
+                        for (final int multicastPartition: multicastPartitions) {
+                            send(topic, key, value, headers, multicastPartition, timestamp, keySerializer, valueSerializer, processorNodeId, context);
+                        }
+                    }
+                }
             } else {
                 throw new StreamsException("Could not get partition information for topic " + topic + " for task " + taskId +
                     ". This can happen if the topic does not exist.");
             }
         } else {
-            partition = null;
+            send(topic, key, value, headers, null, timestamp, keySerializer, valueSerializer, processorNodeId, context);
         }
 
-        send(topic, key, value, headers, partition, timestamp, keySerializer, valueSerializer, processorNodeId, context);
     }
 
     @Override
@@ -212,6 +227,7 @@ public class RecordCollectorImpl implements RecordCollector {
 
             if (exception == null) {
                 final TopicPartition tp = new TopicPartition(metadata.topic(), metadata.partition());
+                log.info("Produced key:{}, value:{} successfully to tp:{}", key, value, tp);
                 if (metadata.offset() >= 0L) {
                     offsets.put(tp, metadata.offset());
                 } else {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java
index 61add951b22..7217666bcf5 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java
@@ -43,6 +43,8 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
+import static org.apache.kafka.clients.producer.RecordMetadata.UNKNOWN_PARTITION;
+
 /**
  * Provides access to the {@link StreamsMetadata} in a KafkaStreams application. This can be used
  * to discover the locations of {@link org.apache.kafka.streams.processor.StateStore}s
@@ -262,9 +264,9 @@ public class StreamsMetadataState {
             // global stores are on every node. if we don't have the host info
             // for this host then just pick the first metadata
             if (thisHost.equals(UNKNOWN_HOST)) {
-                return new KeyQueryMetadata(allMetadata.get(0).hostInfo(), Collections.emptySet(), -1);
+                return new KeyQueryMetadata(allMetadata.get(0).hostInfo(), Collections.emptySet(), UNKNOWN_PARTITION);
             }
-            return new KeyQueryMetadata(localMetadata.get().hostInfo(), Collections.emptySet(), -1);
+            return new KeyQueryMetadata(localMetadata.get().hostInfo(), Collections.emptySet(), UNKNOWN_PARTITION);
         }
 
         final SourceTopicsInfo sourceTopicsInfo = getSourceTopicsInfo(storeName);
@@ -464,10 +466,20 @@ public class StreamsMetadataState {
                                                            final StreamPartitioner<? super K, ?> partitioner,
                                                            final SourceTopicsInfo sourceTopicsInfo) {
 
-        final Integer partition = partitioner.partition(sourceTopicsInfo.topicWithMostPartitions, key, null, sourceTopicsInfo.maxPartitions);
+        // Assuming the partitions method does not return an empty Optional here, which would
+        // mean the default partitioner should be used. This is an optimistic assumption, but
+        // the older partition()-based implementation made the same one.
+        final Set<Integer> partitions = partitioner.partitions(sourceTopicsInfo.topicWithMostPartitions, key, null, sourceTopicsInfo.maxPartitions).get();
+        // The record was dropped and hence won't be found anywhere
+        if (partitions.isEmpty()) {
+            return new KeyQueryMetadata(UNKNOWN_HOST, Collections.emptySet(), UNKNOWN_PARTITION);
+        }
+
         final Set<TopicPartition> matchingPartitions = new HashSet<>();
         for (final String sourceTopic : sourceTopicsInfo.sourceTopics) {
-            matchingPartitions.add(new TopicPartition(sourceTopic, partition));
+            for (final Integer partition : partitions) {
+                matchingPartitions.add(new TopicPartition(sourceTopic, partition));
+            }
         }
 
         HostInfo activeHost = UNKNOWN_HOST;
@@ -489,7 +501,7 @@ public class StreamsMetadataState {
             }
         }
 
-        return new KeyQueryMetadata(activeHost, standbyHosts, partition);
+        return new KeyQueryMetadata(activeHost, standbyHosts, partitions);
     }
 
     private <K> KeyQueryMetadata getKeyQueryMetadataForKey(final String storeName,
@@ -498,10 +510,21 @@ public class StreamsMetadataState {
                                                            final SourceTopicsInfo sourceTopicsInfo,
                                                            final String topologyName) {
         Objects.requireNonNull(topologyName, "topology name must not be null");
-        final Integer partition = partitioner.partition(sourceTopicsInfo.topicWithMostPartitions, key, null, sourceTopicsInfo.maxPartitions);
+
+        // Assuming the partitions method does not return an empty Optional here, which would
+        // mean the default partitioner should be used. This is an optimistic assumption, but
+        // the older partition()-based implementation made the same one.
+        final Set<Integer> partitions = partitioner.partitions(sourceTopicsInfo.topicWithMostPartitions, key, null, sourceTopicsInfo.maxPartitions).get();
+        // The record was dropped and hence won't be found anywhere
+        if (partitions.isEmpty()) {
+            return new KeyQueryMetadata(UNKNOWN_HOST, Collections.emptySet(), UNKNOWN_PARTITION);
+        }
+
         final Set<TopicPartition> matchingPartitions = new HashSet<>();
         for (final String sourceTopic : sourceTopicsInfo.sourceTopics) {
-            matchingPartitions.add(new TopicPartition(sourceTopic, partition));
+            for (final Integer partition : partitions) {
+                matchingPartitions.add(new TopicPartition(sourceTopic, partition));
+            }
         }
 
         HostInfo activeHost = UNKNOWN_HOST;
@@ -527,7 +550,7 @@ public class StreamsMetadataState {
             }
         }
 
-        return new KeyQueryMetadata(activeHost, standbyHosts, partition);
+        return new KeyQueryMetadata(activeHost, standbyHosts, partitions);
     }
 
     private SourceTopicsInfo getSourceTopicsInfo(final String storeName) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java
index 1cff0fc2016..78734b68bb3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java
@@ -39,6 +39,7 @@ import org.apache.kafka.streams.kstream.JoinWindows;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.Named;
 import org.apache.kafka.streams.kstream.Repartitioned;
+import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.TestUtils;
 import org.junit.After;
@@ -65,12 +66,15 @@ import java.util.List;
 import java.util.Objects;
 import java.util.Properties;
 import java.util.Set;
+import java.util.Optional;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static org.apache.kafka.streams.KafkaStreams.State.ERROR;
 import static org.apache.kafka.streams.KafkaStreams.State.REBALANCING;
@@ -107,6 +111,8 @@ public class KStreamRepartitionIntegrationTest {
     private String outputTopic;
     private String applicationId;
 
+    private String safeTestName;
+
     private Properties streamsConfiguration;
     private List<KafkaStreams> kafkaStreamsInstances;
 
@@ -129,7 +135,7 @@ public class KStreamRepartitionIntegrationTest {
         streamsConfiguration = new Properties();
         kafkaStreamsInstances = new ArrayList<>();
 
-        final String safeTestName = safeUniqueTestName(getClass(), testName);
+        safeTestName = safeUniqueTestName(getClass(), testName);
 
         topicB = "topic-b-" + safeTestName;
         inputTopic = "input-topic-" + safeTestName;
@@ -293,6 +299,80 @@ public class KStreamRepartitionIntegrationTest {
         );
     }
 
+    @Test
+    public void shouldRepartitionToMultiplePartitions() throws Exception {
+        final String repartitionName = "broadcasting-partitioner-test";
+        final long timestamp = System.currentTimeMillis();
+        final AtomicInteger partitionerInvocation = new AtomicInteger(0);
+
+        // This test needs to write to an output topic with 4 partitions, so a new one is created here
+        final String broadcastingOutputTopic = "broadcast-output-topic-" + safeTestName;
+        CLUSTER.createTopic(broadcastingOutputTopic, 4, 1);
+
+        final List<KeyValue<Integer, String>> expectedRecordsOnRepartition = Arrays.asList(
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(2, "B"),
+            new KeyValue<>(2, "B"),
+            new KeyValue<>(2, "B"),
+            new KeyValue<>(2, "B")
+        );
+
+        final List<KeyValue<Integer, String>> expectedRecords = expectedRecordsOnRepartition.subList(3, 5);
+
+        class BroadcastingPartitioner implements StreamPartitioner<Integer, String> {
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final Integer key, final String value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final Integer key, final String value, final int numPartitions) {
+                partitionerInvocation.incrementAndGet();
+                return Optional.of(IntStream.range(0, numPartitions).boxed().collect(Collectors.toSet()));
+            }
+        }
+
+        sendEvents(timestamp, expectedRecords);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final Repartitioned<Integer, String> repartitioned = Repartitioned
+            .<Integer, String>as(repartitionName)
+            .withStreamPartitioner(new BroadcastingPartitioner());
+
+        builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String()))
+            .repartition(repartitioned)
+            .to(broadcastingOutputTopic);
+
+        startStreams(builder);
+
+        final String topic = toRepartitionTopicName(repartitionName);
+
+        // Both records should be present on all 4 partitions of the repartition and output topics
+        validateReceivedMessages(
+            new IntegerDeserializer(),
+            new StringDeserializer(),
+            expectedRecordsOnRepartition,
+            topic
+        );
+
+
+        validateReceivedMessages(
+            new IntegerDeserializer(),
+            new StringDeserializer(),
+            expectedRecordsOnRepartition,
+            broadcastingOutputTopic
+        );
+
+        assertTrue(topicExists(topic));
+        assertEquals(expectedRecords.size(), partitionerInvocation.get());
+    }
+
+
     @Test
     public void shouldUseStreamPartitionerForRepartitionOperation() throws Exception {
         final int partition = 1;
@@ -799,24 +879,33 @@ public class KStreamRepartitionIntegrationTest {
                                                  final Deserializer<V> valueSerializer,
                                                  final List<KeyValue<K, V>> expectedRecords) throws Exception {
 
+        validateReceivedMessages(keySerializer, valueSerializer, expectedRecords, outputTopic);
+    }
+
+    private <K, V> void validateReceivedMessages(final Deserializer<K> keySerializer,
+                                                 final Deserializer<V> valueSerializer,
+                                                 final List<KeyValue<K, V>> expectedRecords,
+                                                 final String outputTopic) throws Exception {
+
         final String safeTestName = safeUniqueTestName(getClass(), testName);
         final Properties consumerProperties = new Properties();
         consumerProperties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
         consumerProperties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group-" + safeTestName);
         consumerProperties.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
         consumerProperties.setProperty(
-            ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
-            keySerializer.getClass().getName()
+                ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
+                keySerializer.getClass().getName()
         );
         consumerProperties.setProperty(
-            ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
-            valueSerializer.getClass().getName()
+                ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
+                valueSerializer.getClass().getName()
         );
 
         IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived(
-            consumerProperties,
-            outputTopic,
-            expectedRecords
+                consumerProperties,
+                outputTopic,
+                expectedRecords
         );
     }
+
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java
index b5eb98a31a1..1a9e4635bb1 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java
@@ -28,6 +28,8 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Properties;
 import java.util.Set;
+import java.util.Optional;
+import java.util.Arrays;
 
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.producer.ProducerConfig;
@@ -39,6 +41,7 @@ import org.apache.kafka.streams.KafkaStreams;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.Consumed;
@@ -49,6 +52,7 @@ import org.apache.kafka.streams.kstream.Produced;
 import org.apache.kafka.streams.kstream.Repartitioned;
 import org.apache.kafka.streams.kstream.TableJoined;
 import org.apache.kafka.streams.kstream.ValueJoiner;
+import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.utils.UniqueTopicSerdeScope;
 import org.apache.kafka.test.TestUtils;
@@ -61,6 +65,7 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.Disabled;
 
 @Timeout(600)
 @Tag("integration")
@@ -83,6 +88,20 @@ public class KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest {
     private final static Properties PRODUCER_CONFIG_1 = new Properties();
     private final static Properties PRODUCER_CONFIG_2 = new Properties();
 
+    static class MultiPartitioner implements StreamPartitioner<String, Void> {
+
+        @Override
+        @Deprecated
+        public Integer partition(final String topic, final String key, final Void value, final int numPartitions) {
+            return null;
+        }
+
+        @Override
+        public Optional<Set<Integer>> partitions(final String topic, final String key, final Void value, final int numPartitions) {
+            return Optional.of(new HashSet<>(Arrays.asList(0, 1, 2)));
+        }
+    }
+
     @BeforeAll
     public static void startCluster() throws IOException, InterruptedException {
         CLUSTER.start();
@@ -163,6 +182,35 @@ public class KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest {
         verifyKTableKTableJoin(expectedOne);
     }
 
+    @Disabled("This test works individually but fails when run along with the class. Ignoring for now.")
+    @Test
+    public void shouldThrowIllegalArgumentExceptionWhenCustomPartionerReturnsMultiplePartitions() throws Exception {
+        final String innerJoinType = "INNER";
+        final String queryableName = innerJoinType + "-store1";
+
+        streams = prepareTopologyWithNonSingletonPartitions(queryableName, streamsConfig);
+        streamsTwo = prepareTopologyWithNonSingletonPartitions(queryableName, streamsConfigTwo);
+        streamsThree = prepareTopologyWithNonSingletonPartitions(queryableName, streamsConfigThree);
+
+        final List<KafkaStreams> kafkaStreamsList = asList(streams, streamsTwo, streamsThree);
+
+        for (final KafkaStreams stream: kafkaStreamsList) {
+            stream.setUncaughtExceptionHandler(e -> {
+                assertEquals("The partitions returned by StreamPartitioner#partitions method when used for FK join should be a singleton set", e.getCause().getMessage());
+                return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT;
+            });
+        }
+
+        startApplicationAndWaitUntilRunning(kafkaStreamsList, ofSeconds(120));
+
+        // Sleeping to let processing happen and induce the failure
+        Thread.sleep(60000);
+
+        assertEquals(KafkaStreams.State.ERROR, streams.state());
+        assertEquals(KafkaStreams.State.ERROR, streamsTwo.state());
+        assertEquals(KafkaStreams.State.ERROR, streamsThree.state());
+    }
+
     private void verifyKTableKTableJoin(final Set<KeyValue<String, String>> expectedResult) throws Exception {
         final String innerJoinType = "INNER";
         final String queryableName = innerJoinType + "-store1";
@@ -235,6 +283,48 @@ public class KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest {
         return new KafkaStreams(builder.build(streamsConfig), streamsConfig);
     }
 
+    private static KafkaStreams prepareTopologyWithNonSingletonPartitions(final String queryableName, final Properties streamsConfig) {
+
+        final UniqueTopicSerdeScope serdeScope = new UniqueTopicSerdeScope();
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KTable<String, String> table1 = builder.stream(TABLE_1,
+                        Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)))
+                .repartition(repartitionA())
+                .toTable(Named.as("table.a"));
+
+        final KTable<String, String> table2 = builder
+                .stream(TABLE_2,
+                        Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)))
+                .repartition(repartitionB())
+                .toTable(Named.as("table.b"));
+
+        final Materialized<String, String, KeyValueStore<Bytes, byte[]>> materialized;
+        if (queryableName != null) {
+            materialized = Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as(queryableName)
+                    .withKeySerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true))
+                    .withValueSerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, false))
+                    .withCachingDisabled();
+        } else {
+            throw new RuntimeException("Current implementation of joinOnForeignKey requires a materialized store");
+        }
+
+        final ValueJoiner<String, String, String> joiner = (value1, value2) -> "value1=" + value1 + ",value2=" + value2;
+
+        final TableJoined<String, String> tableJoined = TableJoined.with(
+                new MultiPartitioner(),
+                (topic, key, value, numPartitions) -> Math.abs(key.hashCode()) % numPartitions
+        );
+
+        table1.join(table2, KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest::getKeyB, joiner, tableJoined, materialized)
+                .toStream()
+                .to(OUTPUT,
+                        Produced.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true),
+                                serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)));
+
+        return new KafkaStreams(builder.build(streamsConfig), streamsConfig);
+    }
+
     private static Repartitioned<String, String> repartitionA() {
         final Repartitioned<String, String> repartitioned = Repartitioned.as("a");
         return repartitioned.withKeySerde(Serdes.String()).withValueSerde(Serdes.String())
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java
index 85595cefc3f..44e690b3000 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java
@@ -26,6 +26,7 @@ import org.apache.kafka.streams.KafkaStreams;
 import org.apache.kafka.streams.KafkaStreams.State;
 import org.apache.kafka.streams.KeyQueryMetadata;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.KeyValueTimestamp;
 import org.apache.kafka.streams.StoreQueryParameters;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
@@ -34,6 +35,7 @@ import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.Consumed;
 import org.apache.kafka.streams.kstream.Materialized;
+import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.internals.namedtopology.KafkaStreamsNamedTopologyWrapper;
 import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopologyBuilder;
 import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopologyStoreQueryParameters;
@@ -65,8 +67,11 @@ import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import java.util.Set;
+import java.util.Collections;
 
 import static java.util.Collections.singletonList;
+import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.getStore;
 import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName;
 import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning;
@@ -117,6 +122,54 @@ public class StoreQueryIntegrationTest {
         cluster.stop();
     }
 
+    @Test
+    public void shouldReturnAllPartitionsWhenRecordIsBroadcast() throws Exception {
+
+        class BroadcastingPartitioner implements StreamPartitioner<Integer, String> {
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final Integer key, final String value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final Integer key, final String value, final int numPartitions) {
+                return Optional.of(IntStream.range(0, numPartitions).boxed().collect(Collectors.toSet()));
+            }
+        }
+
+        final int batch1NumMessages = 1;
+        final int key = 1;
+        final Semaphore semaphore = new Semaphore(0);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        getStreamsBuilderWithTopology(builder, semaphore);
+
+        final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, streamsConfiguration());
+
+        startApplicationAndWaitUntilRunning(Collections.singletonList(kafkaStreams1), Duration.ofSeconds(60));
+
+        final Properties producerProps = new Properties();
+        producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+        producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class);
+        producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class);
+
+        final List<KeyValueTimestamp<Integer, Integer>> records = Collections.singletonList(new KeyValueTimestamp<>(key, 0, mockTime.milliseconds()));
+
+        // Send the record to both partitions of INPUT_TOPIC_NAME.
+        IntegrationTestUtils.produceSynchronously(producerProps, false, INPUT_TOPIC_NAME, Optional.of(0), records);
+        IntegrationTestUtils.produceSynchronously(producerProps, false, INPUT_TOPIC_NAME, Optional.of(1), records);
+
+        assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true)));
+
+        until(() -> {
+            final KeyQueryMetadata keyQueryMetadataFetched = kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, new BroadcastingPartitioner());
+            assertThat(keyQueryMetadataFetched.activeHost().host(), is("localhost"));
+            assertThat(keyQueryMetadataFetched.partitions(), is(mkSet(0, 1)));
+            return true;
+        });
+    }
+
     @Test
     public void shouldQueryOnlyActivePartitionStoresByDefault() throws Exception {
         final int batch1NumMessages = 100;
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
index c14988cdae4..f0915c8b88e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
@@ -1283,6 +1283,7 @@ public class IntegrationTestUtils {
                                                                  final int maxMessages) {
         final List<ConsumerRecord<K, V>> consumerRecords;
         consumer.subscribe(Collections.singletonList(topic));
+        System.out.println("Got assignment:" + consumer.assignment());
         final int pollIntervalMs = 100;
         consumerRecords = new ArrayList<>();
         int totalPollTimeMs = 0;
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java
index c7669a978a1..9dfabeacfe8 100644
--- a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java
@@ -45,6 +45,8 @@ import java.time.Instant;
 import java.util.Map;
 import java.util.Properties;
 import java.util.TreeMap;
+import java.util.Optional;
+import java.util.Collections;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -78,8 +80,8 @@ public class KStreamRepartitionTest {
         @SuppressWarnings("unchecked")
         final StreamPartitioner<Integer, String> streamPartitionerMock = mock(StreamPartitioner.class);
 
-        when(streamPartitionerMock.partition(anyString(), eq(0), eq("X0"), anyInt())).thenReturn(1);
-        when(streamPartitionerMock.partition(anyString(), eq(1), eq("X1"), anyInt())).thenReturn(1);
+        when(streamPartitionerMock.partitions(anyString(), eq(0), eq("X0"), anyInt())).thenReturn(Optional.of(Collections.singleton(1)));
+        when(streamPartitionerMock.partitions(anyString(), eq(1), eq("X1"), anyInt())).thenReturn(Optional.of(Collections.singleton(1)));
 
         final String repartitionOperationName = "test";
         final Repartitioned<Integer, String> repartitioned = Repartitioned
@@ -111,8 +113,8 @@ public class KStreamRepartitionTest {
             assertTrue(testOutputTopic.readRecordsToList().isEmpty());
         }
 
-        verify(streamPartitionerMock).partition(anyString(), eq(0), eq("X0"), anyInt());
-        verify(streamPartitionerMock).partition(anyString(), eq(1), eq("X1"), anyInt());
+        verify(streamPartitionerMock).partitions(anyString(), eq(0), eq("X0"), anyInt());
+        verify(streamPartitionerMock).partitions(anyString(), eq(1), eq("X1"), anyInt());
     }
 
     @Test
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java
index a6595257277..17ed8eac97a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java
@@ -72,6 +72,7 @@ public class WindowedStreamPartitionerTest {
                 final TimeWindow window = new TimeWindow(10 * w, 20 * w);
 
                 final Windowed<Integer> windowedKey = new Windowed<>(key, window);
+                @SuppressWarnings("deprecation")
                 final Integer actual = streamPartitioner.partition(topicName, windowedKey, value, infos.size());
 
                 assertEquals(expected, actual);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
index de26c7bfaed..e4d015c991a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
@@ -60,6 +60,8 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Optional;
+import java.util.HashSet;
 import java.util.function.Supplier;
 
 import static java.util.Arrays.asList;
@@ -298,6 +300,17 @@ public class ProcessorTopologyTest {
         assertTrue(outputTopic1.isEmpty());
     }
 
+    @Test
+    public void testDrivingSimpleTopologyWithDroppingPartitioner() {
+        driver = new TopologyTestDriver(createSimpleTopologyWithDroppingPartitioner(), props);
+        final TestInputTopic<String, String> inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER, Instant.ofEpochMilli(0L), Duration.ZERO);
+        final TestOutputTopic<String, String> outputTopic1 =
+                driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.String().deserializer(), Serdes.String().deserializer());
+
+        inputTopic.pipeInput("key1", "value1");
+        assertTrue(outputTopic1.isEmpty());
+    }
+
     @Test
     public void testDrivingStatefulTopology() {
         final String storeName = "entries";
@@ -1583,6 +1596,34 @@ public class ProcessorTopologyTest {
             .addSink("sink2", OUTPUT_TOPIC_2, constantPartitioner(partition), "child2");
     }
 
+    static class DroppingPartitioner implements StreamPartitioner<String, String> {
+
+        @Override
+        @Deprecated
+        public Integer partition(final String topic, final String key, final String value, final int numPartitions) {
+            return null;
+        }
+
+        @Override
+        public Optional<Set<Integer>> partitions(final String topic, final String key, final String value, final int numPartitions) {
+            final Set<Integer> partitions = new HashSet<>();
+            for (int i = 1; i < numPartitions; i += 2) {
+                partitions.add(i);
+            }
+            return Optional.of(partitions);
+        }
+    }
+
+    // Adding a test only for the dropping partitioner, as the output topic has a single partition
+    // and the default implementation of the partitions method already returns a singleton set,
+    // which is covered by the other tests
+    private Topology createSimpleTopologyWithDroppingPartitioner() {
+        return topology
+                .addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+                .addProcessor("processor", ForwardingProcessor::new, "source")
+                .addSink("sink", OUTPUT_TOPIC_1, new DroppingPartitioner(), "processor");
+    }
+
     @Deprecated // testing old PAPI
     private Topology createStatefulTopology(final String storeName) {
         return topology
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index b272ea609e1..9f7d39d25cf 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -67,6 +67,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.Optional;
+import java.util.Set;
+import java.util.HashSet;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -291,6 +296,391 @@ public class RecordCollectorTest {
         assertThrows(UnsupportedOperationException.class, () -> offsets.put(topicPartition, 50L));
     }
 
+    @Test
+    public void shouldSendOnlyToEvenPartitions() {
+        class EvenPartitioner implements StreamPartitioner<String, Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+                final Set<Integer> partitions = new HashSet<>();
+                for (int i = 0; i < numPartitions; i += 2) {
+                    partitions.add(i);
+                }
+                return Optional.of(partitions);
+            }
+        }
+
+        final EvenPartitioner evenPartitioner = new EvenPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                evenPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, evenPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertFalse(offsets.containsKey(new TopicPartition(topic, 1)));
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(18, mockProducer.history().size());
+
+        // returned offsets should not be modified
+        final TopicPartition topicPartition = new TopicPartition(topic, 0);
+        assertThrows(UnsupportedOperationException.class, () -> offsets.put(topicPartition, 50L));
+    }
+
+    @Test
+    public void shouldBroadcastToAllPartitions() {
+
+        class BroadcastingPartitioner implements StreamPartitioner<String, Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+                return Optional.of(IntStream.range(0, numPartitions).boxed().collect(Collectors.toSet()));
+            }
+        }
+
+        final BroadcastingPartitioner broadcastingPartitioner = new BroadcastingPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                broadcastingPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, broadcastingPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 1)));
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(27, mockProducer.history().size());
+
+        // returned offsets should not be modified
+        final TopicPartition topicPartition = new TopicPartition(topic, 0);
+        assertThrows(UnsupportedOperationException.class, () -> offsets.put(topicPartition, 50L));
+    }
+
+    @Test
+    public void shouldDropAllRecords() {
+
+        class DroppingPartitioner implements StreamPartitioner<String, Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+                return Optional.of(Collections.emptySet());
+            }
+        }
+
+        final DroppingPartitioner droppingPartitioner = new DroppingPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                droppingPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final String topic = "topic";
+
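+        // look up the task-level dropped-records metric so the drop count can be asserted after the sends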
+        final Metric recordsDropped = streamsMetrics.metrics().get(new MetricName(
+                "dropped-records-total",
+                "stream-task-metrics",
+                "The total number of dropped records",
+                mkMap(
+                        mkEntry("thread-id", Thread.currentThread().getName()),
+                        mkEntry("task-id", taskId.toString())
+                )
+        ));
+
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+
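+        // the partitioner returns an empty set for every record, so nothing is produced and all 9 records are counted as dropped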
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+        assertTrue(offsets.isEmpty());
+
+        assertEquals(0, mockProducer.history().size());
+        assertThat(recordsDropped.metricValue(), equalTo(9.0));
+
+        // the returned offsets map should be unmodifiable
+        final TopicPartition topicPartition = new TopicPartition(topic, 0);
+        assertThrows(UnsupportedOperationException.class, () -> offsets.put(topicPartition, 50L));
+    }
+
+    @Test
+    public void shouldUseDefaultPartitionerViaPartitions() {
+
+        class DefaultPartitioner implements StreamPartitioner<String, Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+                return Optional.empty();
+            }
+        }
+
+        final DefaultPartitioner defaultPartitioner = new DefaultPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                defaultPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final String topic = "topic";
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        // with the mock producer and no explicit partition, the default producer partitioner (murmur2 hash of the key) picks the partition
+        assertEquals(3L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 1)));
+        assertEquals(1L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(9, mockProducer.history().size());
+    }
+
+    @Test
+    public void shouldUseDefaultPartitionerAsPartitionReturnsNull() {
+
+        final StreamPartitioner<String, Object> streamPartitioner =
+                (topic, key, value, numPartitions) -> null;
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                streamPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final String topic = "topic";
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, streamPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        // with the mock producer and no explicit partition, the default producer partitioner (murmur2 hash of the key) picks the partition
+        assertEquals(3L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 1)));
+        assertEquals(1L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(9, mockProducer.history().size());
+    }
+
+    @Test
+    public void shouldUseDefaultPartitionerAsStreamPartitionerIsNull() {
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                streamPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final String topic = "topic";
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        // with the mock producer and no explicit partition, the default producer partitioner (murmur2 hash of the key) picks the partition
+        assertEquals(3L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 1)));
+        assertEquals(1L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(9, mockProducer.history().size());
+    }
+
     @Test
     public void shouldSendWithNoPartition() {
         final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java
index d44b79885b4..e04acca447a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java
@@ -45,6 +45,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.Optional;
+import java.util.HashSet;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -136,6 +138,23 @@ public class StreamsMetadataStateTest {
         storeNames = mkSet("table-one", "table-two", "merged-table", globalTable);
     }
 
+    static class MultiValuedPartitioner implements StreamPartitioner<String, Object> {
+
+        @Override
+        @Deprecated
+        public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+            return null;
+        }
+
+        @Override
+        public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+            final Set<Integer> partitions = new HashSet<>();
+            partitions.add(0);
+            partitions.add(1);
+            return Optional.of(partitions);
+        }
+    }
+
     @Test
     public void shouldNotThrowExceptionWhenOnChangeNotCalled() {
         final Collection<StreamsMetadata> metadata = new StreamsMetadataState(
@@ -229,7 +248,7 @@ public class StreamsMetadataStateTest {
         metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions,
             cluster.withPartitions(Collections.singletonMap(tp4, new PartitionInfo("topic-three", 1, null, null, null))));
 
-        final KeyQueryMetadata expected = new KeyQueryMetadata(hostThree, mkSet(hostTwo), 0);
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostThree, mkSet(hostTwo), Collections.singleton(0));
         final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("table-three",
                                                                     "the-key",
                                                                     Serdes.String().serializer());
@@ -244,13 +263,30 @@ public class StreamsMetadataStateTest {
         metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions,
             cluster.withPartitions(Collections.singletonMap(tp4, new PartitionInfo("topic-three", 1, null, null, null))));
 
-        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, Collections.emptySet(), 1);
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, Collections.emptySet(), Collections.singleton(1));
 
         final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("table-three",
                 "the-key",
                 partitioner);
         assertEquals(expected, actual);
         assertEquals(1, actual.partition());
+        assertEquals(Collections.singleton(1), actual.partitions());
+    }
+
+    @Test
+    public void shouldGetInstanceWithKeyAndCustomMulticastingPartitioner() {
+        final TopicPartition tp4 = new TopicPartition("topic-three", 0);
+        final TopicPartition tp5 = new TopicPartition("topic-three", 1);
+        hostToActivePartitions.put(hostTwo, mkSet(tp4, tp5));
+
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostThree, Collections.singleton(hostTwo), mkSet(0, 1));
+
+        final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("table-three",
+                "the-key",
+                new MultiValuedPartitioner());
+        assertEquals(expected, actual);
+        assertEquals(-1, actual.partition());
+        assertEquals(mkSet(0, 1), actual.partitions());
     }
 
     @Test
@@ -268,7 +304,7 @@ public class StreamsMetadataStateTest {
         metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions,
                 cluster.withPartitions(Collections.singletonMap(topic2P2, new PartitionInfo("topic-two", 2, null, null, null))));
 
-        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, mkSet(hostOne), 2);
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, mkSet(hostOne), Collections.singleton(2));
 
         final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("merged-table",  "the-key",
             (topic, key, value, numPartitions) -> 2);