This is an automated email from the ASF dual-hosted git repository.

ableegoldman pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 77e294e7fca KAFKA-13602: Adding ability to multicast records (#12803)
77e294e7fca is described below

commit 77e294e7fca31e4e384930aa0c26431cfcc13410
Author: vamossagar12 <[email protected]>
AuthorDate: Tue Dec 6 15:31:38 2022 +0530

    KAFKA-13602: Adding ability to multicast records (#12803)
    
    This PR implements KIP-837, which enhances StreamPartitioner to multicast records.
    
    Reviewers: Anna Sophie Blee-Goldman <[email protected]>, YEONCHEOL JANG
---
 .../org/apache/kafka/streams/KeyQueryMetadata.java |  24 +-
 .../streams/kstream/internals/KTableImpl.java      |  30 +-
 .../internals/WindowedStreamPartitioner.java       |   1 +
 .../kafka/streams/processor/StreamPartitioner.java |  23 ++
 .../internals/DefaultStreamPartitioner.java        |   1 +
 .../processor/internals/RecordCollectorImpl.java   |  24 +-
 .../processor/internals/StreamsMetadataState.java  |  39 ++-
 .../KStreamRepartitionIntegrationTest.java         | 105 +++++-
 ...yInnerJoinCustomPartitionerIntegrationTest.java |  90 +++++
 .../integration/StoreQueryIntegrationTest.java     |  53 +++
 .../integration/utils/IntegrationTestUtils.java    |   1 +
 .../kstream/internals/KStreamRepartitionTest.java  |  10 +-
 .../internals/WindowedStreamPartitionerTest.java   |   1 +
 .../processor/internals/ProcessorTopologyTest.java |  41 +++
 .../processor/internals/RecordCollectorTest.java   | 390 +++++++++++++++++++++
 .../internals/StreamsMetadataStateTest.java        |  42 ++-
 16 files changed, 837 insertions(+), 38 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java 
b/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java
index 9ca495214d6..6461ee7423f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java
@@ -43,10 +43,20 @@ public class KeyQueryMetadata {
 
     private final int partition;
 
+    private final Set<Integer> partitions;
+
     public KeyQueryMetadata(final HostInfo activeHost, final Set<HostInfo> 
standbyHosts, final int partition) {
         this.activeHost = activeHost;
         this.standbyHosts = standbyHosts;
         this.partition = partition;
+        this.partitions = Collections.singleton(partition);
+    }
+
+    public KeyQueryMetadata(final HostInfo activeHost, final Set<HostInfo> 
standbyHosts, final Set<Integer> partitions) {
+        this.activeHost = activeHost;
+        this.standbyHosts = standbyHosts;
+        this.partition = partitions.size() == 1 ? partitions.iterator().next() 
: -1;
+        this.partitions = partitions;
     }
 
     /**
@@ -109,6 +119,16 @@ public class KeyQueryMetadata {
         return partition;
     }
 
+    /**
+     * Get the store partitions corresponding to the key.
+     * A Key can be on multiple partitions if it has been
+     * multicasted using StreamPartitioner#partitions method
+     * @return store partition number
+     */
+    public Set<Integer> partitions() {
+        return partitions;
+    }
+
     @Override
     public boolean equals(final Object obj) {
         if (!(obj instanceof KeyQueryMetadata)) {
@@ -117,7 +137,8 @@ public class KeyQueryMetadata {
         final KeyQueryMetadata keyQueryMetadata = (KeyQueryMetadata) obj;
         return Objects.equals(keyQueryMetadata.activeHost, activeHost)
             && Objects.equals(keyQueryMetadata.standbyHosts, standbyHosts)
-            && Objects.equals(keyQueryMetadata.partition, partition);
+            && (Objects.equals(keyQueryMetadata.partition, partition)
+                || Objects.equals(keyQueryMetadata.partitions, partitions));
     }
 
     @Override
@@ -126,6 +147,7 @@ public class KeyQueryMetadata {
                 "activeHost=" + activeHost +
                 ", standbyHosts=" + standbyHosts +
                 ", partition=" + partition +
+                ", partitions=" + partitions +
                 '}';
     }
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java
index 2abe7f5386b..e34ac2f5841 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java
@@ -78,6 +78,7 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.Optional;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
@@ -1046,7 +1047,18 @@ public class KTableImpl<K, S, V> extends 
AbstractStream<K, V> implements KTable<
         return doJoinOnForeignKey(other, foreignKeyExtractor, joiner, 
TableJoined.with(null, null), materialized, true);
     }
 
-    @SuppressWarnings("unchecked")
+    private final Function<Optional<Set<Integer>>, Integer> getPartition = 
maybeMulticastPartitions -> {
+        if (!maybeMulticastPartitions.isPresent()) {
+            return null;
+        }
+        if (maybeMulticastPartitions.get().size() != 1) {
+            throw new IllegalArgumentException("The partitions returned by 
StreamPartitioner#partitions method when used for FK join should be a singleton 
set");
+        }
+        return maybeMulticastPartitions.get().iterator().next();
+    };
+
+
+    @SuppressWarnings({"unchecked", "deprecation"})
     private <VR, KO, VO> KTable<K, VR> doJoinOnForeignKey(final KTable<KO, VO> 
foreignKeyTable,
                                                           final Function<V, 
KO> foreignKeyExtractor,
                                                           final ValueJoiner<V, 
VO, VR> joiner,
@@ -1069,6 +1081,7 @@ public class KTableImpl<K, S, V> extends 
AbstractStream<K, V> implements KTable<
         enableSendingOldValues(true);
 
         final TableJoinedInternal<K, KO> tableJoinedInternal = new 
TableJoinedInternal<>(tableJoined);
+
         final NamedInternal renamed = new 
NamedInternal(tableJoinedInternal.name());
 
         final String subscriptionTopicName = renamed.suffixWithOrElseGet(
@@ -1118,12 +1131,10 @@ public class KTableImpl<K, S, V> extends 
AbstractStream<K, V> implements KTable<
         );
         builder.addGraphNode(graphNode, subscriptionNode);
 
-
         final StreamPartitioner<KO, SubscriptionWrapper<K>> 
subscriptionSinkPartitioner =
-            tableJoinedInternal.otherPartitioner() == null
-                ? null
-                : (topic, key, val, numPartitions) ->
-                    tableJoinedInternal.otherPartitioner().partition(topic, 
key, null, numPartitions);
+                tableJoinedInternal.otherPartitioner() == null
+                        ? null
+                        : (topic, key, val, numPartitions) -> 
getPartition.apply(tableJoinedInternal.otherPartitioner().partitions(topic, 
key, null, numPartitions));
 
         final StreamSinkNode<KO, SubscriptionWrapper<K>> subscriptionSink = 
new StreamSinkNode<>(
             renamed.suffixWithOrElseGet("-subscription-registration-sink", 
builder, SINK_NAME),
@@ -1196,10 +1207,9 @@ public class KTableImpl<K, S, V> extends 
AbstractStream<K, V> implements KTable<
         
builder.internalTopologyBuilder.addInternalTopic(finalRepartitionTopicName, 
InternalTopicProperties.empty());
 
         final StreamPartitioner<K, SubscriptionResponseWrapper<VO>> 
foreignResponseSinkPartitioner =
-            tableJoinedInternal.partitioner() == null
-                ? (topic, key, subscriptionResponseWrapper, numPartitions) -> 
subscriptionResponseWrapper.getPrimaryPartition()
-                : (topic, key, val, numPartitions) ->
-                    tableJoinedInternal.partitioner().partition(topic, key, 
null, numPartitions);
+                tableJoinedInternal.partitioner() == null
+                        ? (topic, key, subscriptionResponseWrapper, 
numPartitions) -> subscriptionResponseWrapper.getPrimaryPartition()
+                        : (topic, key, val, numPartitions) -> 
getPartition.apply(tableJoinedInternal.partitioner().partitions(topic, key, 
null, numPartitions));
 
         final StreamSinkNode<K, SubscriptionResponseWrapper<VO>> 
foreignResponseSink =
             new StreamSinkNode<>(
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java
 
b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java
index d68a52b8d02..f1ea71981bf 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java
@@ -40,6 +40,7 @@ public class WindowedStreamPartitioner<K, V> implements 
StreamPartitioner<Window
      * @return an integer between 0 and {@code numPartitions-1}, or {@code 
null} if the default partitioning logic should be used
      */
     @Override
+    @Deprecated
     public Integer partition(final String topic, final Windowed<K> 
windowedKey, final V value, final int numPartitions) {
         // for windowed key, the key bytes should never be null
         final byte[] keyBytes = serializer.serializeBaseKey(topic, 
windowedKey);
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java
index 90ffa3a4a83..b4c2483db7d 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java
@@ -18,6 +18,10 @@ package org.apache.kafka.streams.processor;
 
 import org.apache.kafka.streams.Topology;
 
+import java.util.Collections;
+import java.util.Optional;
+import java.util.Set;
+
 /**
  * Determine how records are distributed among the partitions in a Kafka 
topic. If not specified, the underlying producer's
  * {@link org.apache.kafka.clients.producer.internals.DefaultPartitioner} will 
be used to determine the partition.
@@ -58,5 +62,24 @@ public interface StreamPartitioner<K, V> {
      * @param numPartitions the total number of partitions
      * @return an integer between 0 and {@code numPartitions-1}, or {@code 
null} if the default partitioning logic should be used
      */
+    @Deprecated
     Integer partition(String topic, K key, V value, int numPartitions);
+
+    /**
+     * Determine the number(s) of the partition(s) to which a record with the 
given key and value should be sent, 
+     * for the given topic and current partition count
+     * @param topic the topic name this record is sent to
+     * @param key the key of the record
+     * @param value the value of the record
+     * @param numPartitions the total number of partitions
+     * @return an Optional of Set of integers between 0 and {@code 
numPartitions-1},
+     * Empty optional means using default partitioner
+     * Optional of an empty set means the record won't be sent to any 
partitions i.e drop it.
+     * Optional of Set of integers means the partitions to which the record 
should be sent to.
+     * */
+    default Optional<Set<Integer>> partitions(String topic, K key, V value, 
int numPartitions) {
+        final Integer partition = partition(topic, key, value, numPartitions);
+        return partition == null ? Optional.empty() : 
Optional.of(Collections.singleton(partition));
+    }
+
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java
index c7d909c65a3..d51b9791291 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java
@@ -29,6 +29,7 @@ public class DefaultStreamPartitioner<K, V> implements 
StreamPartitioner<K, V> {
     }
 
     @Override
+    @Deprecated
     public Integer partition(final String topic, final K key, final V value, 
final int numPartitions) {
         final byte[] keyBytes = keySerializer.serialize(topic, key);
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index e42dc4b5735..43c329896f6 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -50,6 +50,8 @@ import 
org.apache.kafka.streams.processor.internals.metrics.TopicMetrics;
 
 import org.slf4j.Logger;
 
+import java.util.Set;
+import java.util.Optional;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -130,7 +132,6 @@ public class RecordCollectorImpl implements RecordCollector 
{
                             final String processorNodeId,
                             final InternalProcessorContext<Void, Void> context,
                             final StreamPartitioner<? super K, ? super V> 
partitioner) {
-        final Integer partition;
 
         if (partitioner != null) {
             final List<PartitionInfo> partitions;
@@ -150,16 +151,30 @@ public class RecordCollectorImpl implements 
RecordCollector {
                 );
             }
             if (partitions.size() > 0) {
-                partition = partitioner.partition(topic, key, value, 
partitions.size());
+                final Optional<Set<Integer>> maybeMulticastPartitions = 
partitioner.partitions(topic, key, value, partitions.size());
+                if (!maybeMulticastPartitions.isPresent()) {
+                    // A null//empty partition indicates we should use the 
default partitioner
+                    send(topic, key, value, headers, null, timestamp, 
keySerializer, valueSerializer, processorNodeId, context);
+                } else {
+                    final Set<Integer> multicastPartitions = 
maybeMulticastPartitions.get();
+                    if (multicastPartitions.isEmpty()) {
+                        // If a record is not to be sent to any partition, 
mark it as a dropped record.
+                        log.debug("Not sending the record with key {} , value 
{} to any partition", key, value);
+                        droppedRecordsSensor.record();
+                    } else {
+                        for (final int multicastPartition: 
multicastPartitions) {
+                            send(topic, key, value, headers, 
multicastPartition, timestamp, keySerializer, valueSerializer, processorNodeId, 
context);
+                        }
+                    }
+                }
             } else {
                 throw new StreamsException("Could not get partition 
information for topic " + topic + " for task " + taskId +
                     ". This can happen if the topic does not exist.");
             }
         } else {
-            partition = null;
+            send(topic, key, value, headers, null, timestamp, keySerializer, 
valueSerializer, processorNodeId, context);
         }
 
-        send(topic, key, value, headers, partition, timestamp, keySerializer, 
valueSerializer, processorNodeId, context);
     }
 
     @Override
@@ -212,6 +227,7 @@ public class RecordCollectorImpl implements RecordCollector 
{
 
             if (exception == null) {
                 final TopicPartition tp = new TopicPartition(metadata.topic(), 
metadata.partition());
+                log.info("Produced key:{}, value:{} successfully to tp:{}", 
key, value, tp);
                 if (metadata.offset() >= 0L) {
                     offsets.put(tp, metadata.offset());
                 } else {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java
index 61add951b22..7217666bcf5 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java
@@ -43,6 +43,8 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
+import static 
org.apache.kafka.clients.producer.RecordMetadata.UNKNOWN_PARTITION;
+
 /**
  * Provides access to the {@link StreamsMetadata} in a KafkaStreams 
application. This can be used
  * to discover the locations of {@link 
org.apache.kafka.streams.processor.StateStore}s
@@ -262,9 +264,9 @@ public class StreamsMetadataState {
             // global stores are on every node. if we don't have the host info
             // for this host then just pick the first metadata
             if (thisHost.equals(UNKNOWN_HOST)) {
-                return new KeyQueryMetadata(allMetadata.get(0).hostInfo(), 
Collections.emptySet(), -1);
+                return new KeyQueryMetadata(allMetadata.get(0).hostInfo(), 
Collections.emptySet(), UNKNOWN_PARTITION);
             }
-            return new KeyQueryMetadata(localMetadata.get().hostInfo(), 
Collections.emptySet(), -1);
+            return new KeyQueryMetadata(localMetadata.get().hostInfo(), 
Collections.emptySet(), UNKNOWN_PARTITION);
         }
 
         final SourceTopicsInfo sourceTopicsInfo = 
getSourceTopicsInfo(storeName);
@@ -464,10 +466,20 @@ public class StreamsMetadataState {
                                                            final 
StreamPartitioner<? super K, ?> partitioner,
                                                            final 
SourceTopicsInfo sourceTopicsInfo) {
 
-        final Integer partition = 
partitioner.partition(sourceTopicsInfo.topicWithMostPartitions, key, null, 
sourceTopicsInfo.maxPartitions);
+        // Making an assumption that the partitions method won't return an 
empty Optional set
+        // which means it is not intended to use the default partitioner. It 
is an optimistic
+        // assumption, but the older implementation with partition() also made 
the same assumption.
+        final Set<Integer> partitions = 
partitioner.partitions(sourceTopicsInfo.topicWithMostPartitions, key, null, 
sourceTopicsInfo.maxPartitions).get();
+        // The record was dropped and hence won't be found anywhere
+        if (partitions.isEmpty()) {
+            return new KeyQueryMetadata(UNKNOWN_HOST, Collections.emptySet(), 
UNKNOWN_PARTITION);
+        }
+
         final Set<TopicPartition> matchingPartitions = new HashSet<>();
         for (final String sourceTopic : sourceTopicsInfo.sourceTopics) {
-            matchingPartitions.add(new TopicPartition(sourceTopic, partition));
+            for (final Integer partition : partitions) {
+                matchingPartitions.add(new TopicPartition(sourceTopic, 
partition));
+            }
         }
 
         HostInfo activeHost = UNKNOWN_HOST;
@@ -489,7 +501,7 @@ public class StreamsMetadataState {
             }
         }
 
-        return new KeyQueryMetadata(activeHost, standbyHosts, partition);
+        return new KeyQueryMetadata(activeHost, standbyHosts, partitions);
     }
 
     private <K> KeyQueryMetadata getKeyQueryMetadataForKey(final String 
storeName,
@@ -498,10 +510,21 @@ public class StreamsMetadataState {
                                                            final 
SourceTopicsInfo sourceTopicsInfo,
                                                            final String 
topologyName) {
         Objects.requireNonNull(topologyName, "topology name must not be null");
-        final Integer partition = 
partitioner.partition(sourceTopicsInfo.topicWithMostPartitions, key, null, 
sourceTopicsInfo.maxPartitions);
+
+        // Making an assumption that the partitions method won't return an 
empty Optional set
+        // which means it is not intended to use the default partitioner. It 
is an optimistic
+        // assumption, but the older implementation with partition() also made 
the same assumption.
+        final Set<Integer> partitions = 
partitioner.partitions(sourceTopicsInfo.topicWithMostPartitions, key, null, 
sourceTopicsInfo.maxPartitions).get();
+        // The record was dropped and hence won't be found anywhere
+        if (partitions.isEmpty()) {
+            return new KeyQueryMetadata(UNKNOWN_HOST, Collections.emptySet(), 
UNKNOWN_PARTITION);
+        }
+
         final Set<TopicPartition> matchingPartitions = new HashSet<>();
         for (final String sourceTopic : sourceTopicsInfo.sourceTopics) {
-            matchingPartitions.add(new TopicPartition(sourceTopic, partition));
+            for (final Integer partition : partitions) {
+                matchingPartitions.add(new TopicPartition(sourceTopic, 
partition));
+            }
         }
 
         HostInfo activeHost = UNKNOWN_HOST;
@@ -527,7 +550,7 @@ public class StreamsMetadataState {
             }
         }
 
-        return new KeyQueryMetadata(activeHost, standbyHosts, partition);
+        return new KeyQueryMetadata(activeHost, standbyHosts, partitions);
     }
 
     private SourceTopicsInfo getSourceTopicsInfo(final String storeName) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java
index 1cff0fc2016..78734b68bb3 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java
@@ -39,6 +39,7 @@ import org.apache.kafka.streams.kstream.JoinWindows;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.Named;
 import org.apache.kafka.streams.kstream.Repartitioned;
+import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.TestUtils;
 import org.junit.After;
@@ -65,12 +66,15 @@ import java.util.List;
 import java.util.Objects;
 import java.util.Properties;
 import java.util.Set;
+import java.util.Optional;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static org.apache.kafka.streams.KafkaStreams.State.ERROR;
 import static org.apache.kafka.streams.KafkaStreams.State.REBALANCING;
@@ -107,6 +111,8 @@ public class KStreamRepartitionIntegrationTest {
     private String outputTopic;
     private String applicationId;
 
+    private String safeTestName;
+
     private Properties streamsConfiguration;
     private List<KafkaStreams> kafkaStreamsInstances;
 
@@ -129,7 +135,7 @@ public class KStreamRepartitionIntegrationTest {
         streamsConfiguration = new Properties();
         kafkaStreamsInstances = new ArrayList<>();
 
-        final String safeTestName = safeUniqueTestName(getClass(), testName);
+        safeTestName = safeUniqueTestName(getClass(), testName);
 
         topicB = "topic-b-" + safeTestName;
         inputTopic = "input-topic-" + safeTestName;
@@ -293,6 +299,80 @@ public class KStreamRepartitionIntegrationTest {
         );
     }
 
+    @Test
+    public void shouldRepartitionToMultiplePartitions() throws Exception {
+        final String repartitionName = "broadcasting-partitioner-test";
+        final long timestamp = System.currentTimeMillis();
+        final AtomicInteger partitionerInvocation = new AtomicInteger(0);
+
+        // This test needs to write to an output topic with 4 partitions. 
Hence, creating a new one
+        final String broadcastingOutputTopic = "broadcast-output-topic-" + 
safeTestName;
+        CLUSTER.createTopic(broadcastingOutputTopic, 4, 1);
+
+        final List<KeyValue<Integer, String>> expectedRecordsOnRepartition = 
Arrays.asList(
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(1, "A"),
+            new KeyValue<>(2, "B"),
+            new KeyValue<>(2, "B"),
+            new KeyValue<>(2, "B"),
+            new KeyValue<>(2, "B")
+        );
+
+        final List<KeyValue<Integer, String>> expectedRecords = 
expectedRecordsOnRepartition.subList(3, 5);
+
+        class BroadcastingPartitioner implements StreamPartitioner<Integer, 
String> {
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final Integer key, 
final String value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final 
Integer key, final String value, final int numPartitions) {
+                partitionerInvocation.incrementAndGet();
+                return Optional.of(IntStream.range(0, 
numPartitions).boxed().collect(Collectors.toSet()));
+            }
+        }
+
+        sendEvents(timestamp, expectedRecords);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final Repartitioned<Integer, String> repartitioned = Repartitioned
+            .<Integer, String>as(repartitionName)
+            .withStreamPartitioner(new BroadcastingPartitioner());
+
+        builder.stream(inputTopic, Consumed.with(Serdes.Integer(), 
Serdes.String()))
+            .repartition(repartitioned)
+            .to(broadcastingOutputTopic);
+
+        startStreams(builder);
+
+        final String topic = toRepartitionTopicName(repartitionName);
+
+        // Both records should be there on all 4 partitions of repartition and 
output topic
+        validateReceivedMessages(
+            new IntegerDeserializer(),
+            new StringDeserializer(),
+            expectedRecordsOnRepartition,
+            topic
+        );
+
+
+        validateReceivedMessages(
+            new IntegerDeserializer(),
+            new StringDeserializer(),
+            expectedRecordsOnRepartition,
+            broadcastingOutputTopic
+        );
+
+        assertTrue(topicExists(topic));
+        assertEquals(expectedRecords.size(), partitionerInvocation.get());
+    }
+
+
     @Test
     public void shouldUseStreamPartitionerForRepartitionOperation() throws 
Exception {
         final int partition = 1;
@@ -799,24 +879,33 @@ public class KStreamRepartitionIntegrationTest {
                                                  final Deserializer<V> 
valueSerializer,
                                                  final List<KeyValue<K, V>> 
expectedRecords) throws Exception {
 
+        validateReceivedMessages(keySerializer, valueSerializer, 
expectedRecords, outputTopic);
+    }
+
+    private <K, V> void validateReceivedMessages(final Deserializer<K> 
keySerializer,
+                                                 final Deserializer<V> 
valueSerializer,
+                                                 final List<KeyValue<K, V>> 
expectedRecords,
+                                                 final String outputTopic) 
throws Exception {
+
         final String safeTestName = safeUniqueTestName(getClass(), testName);
         final Properties consumerProperties = new Properties();
         
consumerProperties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, 
CLUSTER.bootstrapServers());
         consumerProperties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, 
"group-" + safeTestName);
         
consumerProperties.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, 
"earliest");
         consumerProperties.setProperty(
-            ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
-            keySerializer.getClass().getName()
+                ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
+                keySerializer.getClass().getName()
         );
         consumerProperties.setProperty(
-            ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
-            valueSerializer.getClass().getName()
+                ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
+                valueSerializer.getClass().getName()
         );
 
         IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived(
-            consumerProperties,
-            outputTopic,
-            expectedRecords
+                consumerProperties,
+                outputTopic,
+                expectedRecords
         );
     }
+
 }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java
index b5eb98a31a1..1a9e4635bb1 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java
@@ -28,6 +28,8 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Properties;
 import java.util.Set;
+import java.util.Optional;
+import java.util.Arrays;
 
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.producer.ProducerConfig;
@@ -39,6 +41,7 @@ import org.apache.kafka.streams.KafkaStreams;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.Consumed;
@@ -49,6 +52,7 @@ import org.apache.kafka.streams.kstream.Produced;
 import org.apache.kafka.streams.kstream.Repartitioned;
 import org.apache.kafka.streams.kstream.TableJoined;
 import org.apache.kafka.streams.kstream.ValueJoiner;
+import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.utils.UniqueTopicSerdeScope;
 import org.apache.kafka.test.TestUtils;
@@ -61,6 +65,7 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Tag;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.Disabled;
 
 @Timeout(600)
 @Tag("integration")
@@ -83,6 +88,20 @@ public class 
KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest {
     private final static Properties PRODUCER_CONFIG_1 = new Properties();
     private final static Properties PRODUCER_CONFIG_2 = new Properties();
 
+    static class MultiPartitioner implements StreamPartitioner<String, Void> {
+
+        @Override
+        @Deprecated
+        public Integer partition(final String topic, final String key, final 
Void value, final int numPartitions) {
+            return null;
+        }
+
+        @Override
+        public Optional<Set<Integer>> partitions(final String topic, final 
String key, final Void value, final int numPartitions) {
+            return Optional.of(new HashSet<>(Arrays.asList(0, 1, 2)));
+        }
+    }
+
     @BeforeAll
     public static void startCluster() throws IOException, InterruptedException 
{
         CLUSTER.start();
@@ -163,6 +182,35 @@ public class 
KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest {
         verifyKTableKTableJoin(expectedOne);
     }
 
+    /**
+     * Runs three clients of an FK-join topology whose subscription partitioner ({@link MultiPartitioner})
+     * returns several partitions, and expects every client to end up in ERROR with the FK-join
+     * "singleton set" message as the cause of the uncaught exception.
+     */
+    @Disabled("This test works individually but fails when run along with the class. Ignoring for now.")
+    @Test
+    public void shouldThrowIllegalArgumentExceptionWhenCustomPartitionerReturnsMultiplePartitions() throws Exception {
+        final String innerJoinType = "INNER";
+        final String queryableName = innerJoinType + "-store1";
+
+        streams = prepareTopologyWithNonSingletonPartitions(queryableName, streamsConfig);
+        streamsTwo = prepareTopologyWithNonSingletonPartitions(queryableName, streamsConfigTwo);
+        streamsThree = prepareTopologyWithNonSingletonPartitions(queryableName, streamsConfigThree);
+
+        final List<KafkaStreams> kafkaStreamsList = asList(streams, streamsTwo, streamsThree);
+
+        // The handler is invoked on the stream threads, where a failing assertion would not fail
+        // this test. Record the exceptions (thread-safely) and assert on them from the test thread.
+        final List<Throwable> uncaughtExceptions = new java.util.concurrent.CopyOnWriteArrayList<>();
+        for (final KafkaStreams stream : kafkaStreamsList) {
+            stream.setUncaughtExceptionHandler(e -> {
+                uncaughtExceptions.add(e);
+                return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT;
+            });
+        }
+
+        startApplicationAndWaitUntilRunning(kafkaStreamsList, ofSeconds(120));
+
+        // Poll for the failure instead of sleeping for a fixed minute.
+        TestUtils.waitForCondition(
+            () -> kafkaStreamsList.stream().allMatch(s -> s.state() == KafkaStreams.State.ERROR),
+            60000L,
+            "Expected all KafkaStreams clients to transition to ERROR"
+        );
+
+        assertEquals(KafkaStreams.State.ERROR, streams.state());
+        assertEquals(KafkaStreams.State.ERROR, streamsTwo.state());
+        assertEquals(KafkaStreams.State.ERROR, streamsThree.state());
+        for (final Throwable e : uncaughtExceptions) {
+            assertEquals("The partitions returned by StreamPartitioner#partitions method when used for FK join should be a singleton set", e.getCause().getMessage());
+        }
+    }
+
     private void verifyKTableKTableJoin(final Set<KeyValue<String, String>> 
expectedResult) throws Exception {
         final String innerJoinType = "INNER";
         final String queryableName = innerJoinType + "-store1";
@@ -235,6 +283,48 @@ public class 
KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest {
         return new KafkaStreams(builder.build(streamsConfig), streamsConfig);
     }
 
+    /**
+     * Builds a KTable-KTable foreign-key inner join between TABLE_1 and TABLE_2 whose
+     * {@link TableJoined} uses {@link MultiPartitioner} as its first partitioner (non-singleton
+     * result) and a hash-based lambda as its second, writing join results to OUTPUT.
+     *
+     * @param queryableName name of the materialized join store; must be non-null
+     * @param streamsConfig client configuration used for serde scoping and the topology
+     * @return an unstarted KafkaStreams client for the topology
+     * @throws RuntimeException if {@code queryableName} is null
+     */
+    private static KafkaStreams 
prepareTopologyWithNonSingletonPartitions(final String queryableName, final 
Properties streamsConfig) {
+
+        final UniqueTopicSerdeScope serdeScope = new UniqueTopicSerdeScope();
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KTable<String, String> table1 = builder.stream(TABLE_1,
+                        
Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), 
serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)))
+                .repartition(repartitionA())
+                .toTable(Named.as("table.a"));
+
+        final KTable<String, String> table2 = builder
+                .stream(TABLE_2,
+                        
Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), 
serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)))
+                .repartition(repartitionB())
+                .toTable(Named.as("table.b"));
+
+        final Materialized<String, String, KeyValueStore<Bytes, byte[]>> 
materialized;
+        if (queryableName != null) {
+            materialized = Materialized.<String, String, KeyValueStore<Bytes, 
byte[]>>as(queryableName)
+                    .withKeySerde(serdeScope.decorateSerde(Serdes.String(), 
streamsConfig, true))
+                    .withValueSerde(serdeScope.decorateSerde(Serdes.String(), 
streamsConfig, false))
+                    .withCachingDisabled();
+        } else {
+            throw new RuntimeException("Current implementation of 
joinOnForeignKey requires a materialized store");
+        }
+
+        final ValueJoiner<String, String, String> joiner = (value1, value2) -> 
"value1=" + value1 + ",value2=" + value2;
+
+        // First partitioner returns several partitions (the condition under test); the second is a
+        // plain hash-mod partitioner.
+        final TableJoined<String, String> tableJoined = TableJoined.with(
+                new MultiPartitioner(),
+                (topic, key, value, numPartitions) -> Math.abs(key.hashCode()) 
% numPartitions
+        );
+
+        table1.join(table2, 
KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest::getKeyB, 
joiner, tableJoined, materialized)
+                .toStream()
+                .to(OUTPUT,
+                        
Produced.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true),
+                                serdeScope.decorateSerde(Serdes.String(), 
streamsConfig, false)));
+
+        return new KafkaStreams(builder.build(streamsConfig), streamsConfig);
+    }
+
     private static Repartitioned<String, String> repartitionA() {
         final Repartitioned<String, String> repartitioned = 
Repartitioned.as("a");
         return 
repartitioned.withKeySerde(Serdes.String()).withValueSerde(Serdes.String())
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java
index 85595cefc3f..44e690b3000 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java
@@ -26,6 +26,7 @@ import org.apache.kafka.streams.KafkaStreams;
 import org.apache.kafka.streams.KafkaStreams.State;
 import org.apache.kafka.streams.KeyQueryMetadata;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.KeyValueTimestamp;
 import org.apache.kafka.streams.StoreQueryParameters;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
@@ -34,6 +35,7 @@ import 
org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.Consumed;
 import org.apache.kafka.streams.kstream.Materialized;
+import org.apache.kafka.streams.processor.StreamPartitioner;
 import 
org.apache.kafka.streams.processor.internals.namedtopology.KafkaStreamsNamedTopologyWrapper;
 import 
org.apache.kafka.streams.processor.internals.namedtopology.NamedTopologyBuilder;
 import 
org.apache.kafka.streams.processor.internals.namedtopology.NamedTopologyStoreQueryParameters;
@@ -65,8 +67,11 @@ import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import java.util.Set;
+import java.util.Collections;
 
 import static java.util.Collections.singletonList;
+import static org.apache.kafka.common.utils.Utils.mkSet;
 import static 
org.apache.kafka.streams.integration.utils.IntegrationTestUtils.getStore;
 import static 
org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName;
 import static 
org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning;
@@ -117,6 +122,54 @@ public class StoreQueryIntegrationTest {
         cluster.stop();
     }
 
+    // Verifies that queryMetadataForKey, given a partitioner whose partitions() returns every
+    // partition, reports all of them ({0, 1}) in KeyQueryMetadata#partitions.
+    @Test
+    public void shouldReturnAllPartitionsWhenRecordIsBroadcast() throws 
Exception {
+
+        // partitions() returns the full range [0, numPartitions); partition() is unused (null).
+        class BroadcastingPartitioner implements StreamPartitioner<Integer, 
String> {
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final Integer key, 
final String value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final 
Integer key, final String value, final int numPartitions) {
+                return Optional.of(IntStream.range(0, 
numPartitions).boxed().collect(Collectors.toSet()));
+            }
+        }
+
+        final int batch1NumMessages = 1;
+        final int key = 1;
+        final Semaphore semaphore = new Semaphore(0);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        getStreamsBuilderWithTopology(builder, semaphore);
+
+        final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, 
streamsConfiguration());
+
+        
startApplicationAndWaitUntilRunning(Collections.singletonList(kafkaStreams1), 
Duration.ofSeconds(60));
+
+        final Properties producerProps = new Properties();
+        producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, 
cluster.bootstrapServers());
+        producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, 
IntegerSerializer.class);
+        producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, 
IntegerSerializer.class);
+
+        final List<KeyValueTimestamp<Integer, Integer>> records = 
Collections.singletonList(new KeyValueTimestamp<>(key, 0, 
mockTime.milliseconds()));
+
+        // Send the record to both partitions of INPUT_TOPIC_NAME.
+        IntegrationTestUtils.produceSynchronously(producerProps, false, 
INPUT_TOPIC_NAME, Optional.of(0), records);
+        IntegrationTestUtils.produceSynchronously(producerProps, false, 
INPUT_TOPIC_NAME, Optional.of(1), records);
+
+        // Wait until at least batch1NumMessages permits are released; presumably the topology from
+        // getStreamsBuilderWithTopology releases one per processed record (TODO confirm).
+        assertThat(semaphore.tryAcquire(batch1NumMessages, 60, 
TimeUnit.SECONDS), is(equalTo(true)));
+
+        // NOTE(review): until(...) presumably retries the lambda until it succeeds; verify its contract.
+        until(() -> {
+            final KeyQueryMetadata keyQueryMetadataFetched = 
kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, new 
BroadcastingPartitioner());
+            assertThat(keyQueryMetadataFetched.activeHost().host(), 
is("localhost"));
+            assertThat(keyQueryMetadataFetched.partitions(), is(mkSet(0, 1)));
+            return true;
+        });
+    }
+
     @Test
     public void shouldQueryOnlyActivePartitionStoresByDefault() throws 
Exception {
         final int batch1NumMessages = 100;
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
 
b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
index c14988cdae4..f0915c8b88e 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
@@ -1283,6 +1283,7 @@ public class IntegrationTestUtils {
                                                                  final int 
maxMessages) {
         final List<ConsumerRecord<K, V>> consumerRecords;
         consumer.subscribe(Collections.singletonList(topic));
+        System.out.println("Got assignment:" + consumer.assignment());
         final int pollIntervalMs = 100;
         consumerRecords = new ArrayList<>();
         int totalPollTimeMs = 0;
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java
index c7669a978a1..9dfabeacfe8 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java
@@ -45,6 +45,8 @@ import java.time.Instant;
 import java.util.Map;
 import java.util.Properties;
 import java.util.TreeMap;
+import java.util.Optional;
+import java.util.Collections;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -78,8 +80,8 @@ public class KStreamRepartitionTest {
         @SuppressWarnings("unchecked")
         final StreamPartitioner<Integer, String> streamPartitionerMock = 
mock(StreamPartitioner.class);
 
-        when(streamPartitionerMock.partition(anyString(), eq(0), eq("X0"), 
anyInt())).thenReturn(1);
-        when(streamPartitionerMock.partition(anyString(), eq(1), eq("X1"), 
anyInt())).thenReturn(1);
+        when(streamPartitionerMock.partitions(anyString(), eq(0), eq("X0"), 
anyInt())).thenReturn(Optional.of(Collections.singleton(1)));
+        when(streamPartitionerMock.partitions(anyString(), eq(1), eq("X1"), 
anyInt())).thenReturn(Optional.of(Collections.singleton(1)));
 
         final String repartitionOperationName = "test";
         final Repartitioned<Integer, String> repartitioned = Repartitioned
@@ -111,8 +113,8 @@ public class KStreamRepartitionTest {
             assertTrue(testOutputTopic.readRecordsToList().isEmpty());
         }
 
-        verify(streamPartitionerMock).partition(anyString(), eq(0), eq("X0"), 
anyInt());
-        verify(streamPartitionerMock).partition(anyString(), eq(1), eq("X1"), 
anyInt());
+        verify(streamPartitionerMock).partitions(anyString(), eq(0), eq("X0"), 
anyInt());
+        verify(streamPartitionerMock).partitions(anyString(), eq(1), eq("X1"), 
anyInt());
     }
 
     @Test
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java
index a6595257277..17ed8eac97a 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java
@@ -72,6 +72,7 @@ public class WindowedStreamPartitionerTest {
                 final TimeWindow window = new TimeWindow(10 * w, 20 * w);
 
                 final Windowed<Integer> windowedKey = new Windowed<>(key, 
window);
+                @SuppressWarnings("deprecation")
                 final Integer actual = streamPartitioner.partition(topicName, 
windowedKey, value, infos.size());
 
                 assertEquals(expected, actual);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
index de26c7bfaed..e4d015c991a 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java
@@ -60,6 +60,8 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Optional;
+import java.util.HashSet;
 import java.util.function.Supplier;
 
 import static java.util.Arrays.asList;
@@ -298,6 +300,17 @@ public class ProcessorTopologyTest {
         assertTrue(outputTopic1.isEmpty());
     }
 
+    // Pipes one record through a sink using DroppingPartitioner and expects nothing on
+    // OUTPUT_TOPIC_1. DroppingPartitioner selects only odd partitions, which is an empty set when
+    // the output topic has a single partition (assumed for the test driver here; TODO confirm).
+    @Test
+    public void testDrivingSimpleTopologyWithDroppingPartitioner() {
+        driver = new 
TopologyTestDriver(createSimpleTopologyWithDroppingPartitioner(), props);
+        final TestInputTopic<String, String> inputTopic = 
driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER, 
Instant.ofEpochMilli(0L), Duration.ZERO);
+        final TestOutputTopic<String, String> outputTopic1 =
+                driver.createOutputTopic(OUTPUT_TOPIC_1, 
Serdes.String().deserializer(), Serdes.String().deserializer());
+
+        inputTopic.pipeInput("key1", "value1");
+        assertTrue(outputTopic1.isEmpty());
+    }
+
     @Test
     public void testDrivingStatefulTopology() {
         final String storeName = "entries";
@@ -1583,6 +1596,34 @@ public class ProcessorTopologyTest {
             .addSink("sink2", OUTPUT_TOPIC_2, constantPartitioner(partition), 
"child2");
     }
 
+    /**
+     * Partitioner that selects only odd partition numbers (1, 3, 5, ...) below {@code numPartitions}.
+     * For a topic with a single partition the returned set is empty, so the record is dropped.
+     * The deprecated single-partition {@code partition} method returns null.
+     */
+    static class DroppingPartitioner implements StreamPartitioner<String, 
String> {
+
+        @Override
+        @Deprecated
+        public Integer partition(final String topic, final String key, final 
String value, final int numPartitions) {
+            return null;
+        }
+
+        @Override
+        public Optional<Set<Integer>> partitions(final String topic, final 
String key, final String value, final int numPartitions) {
+            final Set<Integer> partitions = new HashSet<>();
+            // Start at 1 and step by 2: odd partitions only.
+            for (int i = 1; i < numPartitions; i += 2) {
+                partitions.add(i);
+            }
+            return Optional.of(partitions);
+        }
+    }
+
+    // Only the dropping case is tested with this topology: OUTPUT_TOPIC_1 is a single-partition
+    // topic, so DroppingPartitioner (odd partitions only) returns an empty set and every record is
+    // dropped. The default partitions() implementation, which returns a singleton set, is already
+    // exercised by other tests.
+    private Topology createSimpleTopologyWithDroppingPartitioner() {
+        return topology
+                .addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, 
INPUT_TOPIC_1)
+                .addProcessor("processor", ForwardingProcessor::new, "source")
+                .addSink("sink", OUTPUT_TOPIC_1, new DroppingPartitioner(), 
"processor");
+    }
+
     @Deprecated // testing old PAPI
     private Topology createStatefulTopology(final String storeName) {
         return topology
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index b272ea609e1..9f7d39d25cf 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -67,6 +67,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.Optional;
+import java.util.Set;
+import java.util.HashSet;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -291,6 +296,391 @@ public class RecordCollectorTest {
         assertThrows(UnsupportedOperationException.class, () -> 
offsets.put(topicPartition, 50L));
     }
 
+    // Multicast via partitions(): each record is sent to every even partition of `topic`, and
+    // partition 1 never receives anything.
+    @Test
+    public void shouldSendOnlyToEvenPartitions() {
+        // Returns {0, 2, 4, ...} below numPartitions; partition() is unused (null).
+        class EvenPartitioner implements StreamPartitioner<String, Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, 
final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final 
String key, final Object value, final int numPartitions) {
+                final Set<Integer> partitions = new HashSet<>();
+                for (int i = 0; i < numPartitions; i += 2) {
+                    partitions.add(i);
+                }
+                return Optional.of(partitions);
+            }
+        }
+
+        final EvenPartitioner evenPartitioner = new EvenPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                evenPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final Headers headers = new RecordHeaders(new Header[] {new 
RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, 
stringSerializer, null, context, evenPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        // 9 sends, each to partitions 0 and 2 (the asserts imply `topic` has partitions 0..2):
+        // 18 producer records, last offset 8 on each even partition, nothing on partition 1.
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertFalse(offsets.containsKey(new TopicPartition(topic, 1)));
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(18, mockProducer.history().size());
+
+        // returned offsets should not be modified
+        final TopicPartition topicPartition = new TopicPartition(topic, 0);
+        assertThrows(UnsupportedOperationException.class, () -> 
offsets.put(topicPartition, 50L));
+    }
+
+    // Broadcast via partitions(): each record is sent to every partition of `topic`.
+    @Test
+    public void shouldBroadcastToAllPartitions() {
+
+        // Returns the full range [0, numPartitions); partition() is unused (null).
+        class BroadcastingPartitioner implements StreamPartitioner<String, 
Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, 
final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final 
String key, final Object value, final int numPartitions) {
+                return Optional.of(IntStream.range(0, 
numPartitions).boxed().collect(Collectors.toSet()));
+            }
+        }
+
+        final BroadcastingPartitioner broadcastingPartitioner = new 
BroadcastingPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                broadcastingPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final Headers headers = new RecordHeaders(new Header[] {new 
RecordHeader("key", "value".getBytes())});
+
+        collector.send(topic, "3", "0", null, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, 
stringSerializer, null, context, broadcastingPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        // 9 sends x 3 partitions (the asserts imply `topic` has partitions 0..2) = 27 producer
+        // records; the last offset on every partition is 8.
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 1)));
+        assertEquals(8L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(27, mockProducer.history().size());
+
+        // returned offsets should not be modified
+        final TopicPartition topicPartition = new TopicPartition(topic, 0);
+        assertThrows(UnsupportedOperationException.class, () -> 
offsets.put(topicPartition, 50L));
+    }
+
+    // An empty set from partitions() means "drop": nothing is produced, no offsets are recorded,
+    // and every send is counted in the task's dropped-records-total metric.
+    @Test
+    public void shouldDropAllRecords() {
+
+        // partitions() always returns an empty set; partition() is unused (null).
+        class DroppingPartitioner implements StreamPartitioner<String, Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+                return Optional.of(Collections.emptySet());
+            }
+        }
+
+        final DroppingPartitioner droppingPartitioner = new DroppingPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                droppingPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final Metric recordsDropped = streamsMetrics.metrics().get(new MetricName(
+                "dropped-records-total",
+                "stream-task-metrics",
+                "The total number of dropped records",
+                mkMap(
+                        mkEntry("thread-id", Thread.currentThread().getName()),
+                        mkEntry("task-id", taskId.toString())
+                )
+        ));
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        // All sends use the test-class `topic` field, i.e. the same topic the sink node is wired to.
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, droppingPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+        assertTrue(offsets.isEmpty());
+
+        // Nothing reaches the producer, and each of the 9 sends is counted as dropped.
+        assertEquals(0, mockProducer.history().size());
+        assertThat(recordsDropped.metricValue(), equalTo(9.0));
+
+        // returned offsets should not be modified
+        final TopicPartition topicPartition = new TopicPartition(topic, 0);
+        assertThrows(UnsupportedOperationException.class, () -> offsets.put(topicPartition, 50L));
+    }
+
+    // Optional.empty() from partitions() means "no opinion": the record is sent without an explicit
+    // partition, so the producer's own partitioner chooses one.
+    @Test
+    public void shouldUseDefaultPartitionerViaPartitions() {
+
+        // partitions() always returns Optional.empty(); partition() is unused (null).
+        class DefaultPartitioner implements StreamPartitioner<String, Object> {
+
+            @Override
+            @Deprecated
+            public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+                return null;
+            }
+
+            @Override
+            public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+                return Optional.empty();
+            }
+        }
+
+        final DefaultPartitioner defaultPartitioner = new DefaultPartitioner();
+
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                defaultPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        // All sends use the test-class `topic` field, i.e. the same topic the sink node is wired to.
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, context, defaultPartitioner);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        // with mock producer without specific partition, we would use default producer partitioner with murmur hash
+        assertEquals(3L, (long) offsets.get(new TopicPartition(topic, 0)));
+        assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 1)));
+        assertEquals(1L, (long) offsets.get(new TopicPartition(topic, 2)));
+        assertEquals(9, mockProducer.history().size());
+    }
+
+    @Test
+    public void shouldUseDefaultPartitionerAsPartitionReturnsNull() {
+        // Custom partitioner whose (deprecated) single-partition callback always answers null.
+        final StreamPartitioner<String, Object> nullReturningPartitioner = (t, k, v, n) -> null;
+        final SinkNode<?, ?> sink = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                nullReturningPartitioner);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sink),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(logContext, taskId, streamsProducer, productionExceptionHandler, streamsMetrics, topology);
+
+        final String topicName = "topic";
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        // Send order: five keys without headers, three keys with headers, then one more without.
+        for (final String key : new String[] {"3", "9", "27", "81", "243"}) {
+            collector.send(topicName, key, "0", null, null, stringSerializer, stringSerializer, null, context, nullReturningPartitioner);
+        }
+        for (final String key : new String[] {"28", "82", "244"}) {
+            collector.send(topicName, key, "0", headers, null, stringSerializer, stringSerializer, null, context, nullReturningPartitioner);
+        }
+        collector.send(topicName, "245", "0", null, null, stringSerializer, stringSerializer, null, context, nullReturningPartitioner);
+
+        // No partition is chosen by the custom partitioner, so the mock producer's default
+        // producer partitioner (murmur hash) distributes the records.
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+        assertEquals(3L, (long) offsets.get(new TopicPartition(topicName, 0)));
+        assertEquals(2L, (long) offsets.get(new TopicPartition(topicName, 1)));
+        assertEquals(1L, (long) offsets.get(new TopicPartition(topicName, 2)));
+        assertEquals(9, mockProducer.history().size());
+    }
+
+    /**
+     * Verifies that records fall back to the producer's default partitioning when no
+     * {@link StreamPartitioner} is supplied at all: neither on the sink node nor on {@code send()}.
+     */
+    @Test
+    public void shouldUseDefaultPartitionerAsStreamPartitionerIsNull() {
+        // Build the sink without a partitioner as well; wiring the shared custom partitioner in here
+        // would contradict the scenario this test is named after.
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+                sinkNodeName,
+                new StaticTopicNameExtractor<>(topic),
+                stringSerializer,
+                byteArraySerializer,
+                null);
+        topology = new ProcessorTopology(
+                emptyList(),
+                emptyMap(),
+                singletonMap(topic, sinkNode),
+                emptyList(),
+                emptyList(),
+                emptyMap(),
+                emptySet()
+        );
+        collector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                streamsProducer,
+                productionExceptionHandler,
+                streamsMetrics,
+                topology
+        );
+
+        // Distinct name so the local does not shadow the class-level `topic` field used above.
+        final String topicName = "topic";
+
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
+
+        collector.send(topicName, "3", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "9", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "27", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "81", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "243", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "28", "0", headers, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "82", "0", headers, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "244", "0", headers, null, stringSerializer, stringSerializer, null, context, null);
+        collector.send(topicName, "245", "0", null, null, stringSerializer, stringSerializer, null, context, null);
+
+        final Map<TopicPartition, Long> offsets = collector.offsets();
+
+        // with mock producer without specific partition, we would use default producer partitioner with murmur hash
+        assertEquals(3L, (long) offsets.get(new TopicPartition(topicName, 0)));
+        assertEquals(2L, (long) offsets.get(new TopicPartition(topicName, 1)));
+        assertEquals(1L, (long) offsets.get(new TopicPartition(topicName, 2)));
+        assertEquals(9, mockProducer.history().size());
+    }
+
     @Test
     public void shouldSendWithNoPartition() {
         final Headers headers = new RecordHeaders(new Header[] {new 
RecordHeader("key", "value".getBytes())});
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java
index d44b79885b4..e04acca447a 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java
@@ -45,6 +45,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.Optional;
+import java.util.HashSet;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -136,6 +138,23 @@ public class StreamsMetadataStateTest {
         storeNames = mkSet("table-one", "table-two", "merged-table", 
globalTable);
     }
 
+    /**
+     * Test partitioner that multicasts every record to partitions 0 and 1 via {@code partitions()}.
+     * The deprecated single-partition {@code partition()} callback always returns {@code null}.
+     */
+    static class MultiValuedPartitioner implements StreamPartitioner<String, Object> {
+
+        @Override
+        public Optional<Set<Integer>> partitions(final String topic, final String key, final Object value, final int numPartitions) {
+            final Set<Integer> targets = new HashSet<>();
+            Collections.addAll(targets, 0, 1);
+            return Optional.of(targets);
+        }
+
+        @Deprecated
+        @Override
+        public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
+            return null;
+        }
+    }
+
     @Test
     public void shouldNotThrowExceptionWhenOnChangeNotCalled() {
         final Collection<StreamsMetadata> metadata = new StreamsMetadataState(
@@ -229,7 +248,7 @@ public class StreamsMetadataStateTest {
         metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions,
             cluster.withPartitions(Collections.singletonMap(tp4, new 
PartitionInfo("topic-three", 1, null, null, null))));
 
-        final KeyQueryMetadata expected = new KeyQueryMetadata(hostThree, 
mkSet(hostTwo), 0);
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostThree, 
mkSet(hostTwo), Collections.singleton(0));
         final KeyQueryMetadata actual = 
metadataState.getKeyQueryMetadataForKey("table-three",
                                                                     "the-key",
                                                                     
Serdes.String().serializer());
@@ -244,13 +263,30 @@ public class StreamsMetadataStateTest {
         metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions,
             cluster.withPartitions(Collections.singletonMap(tp4, new 
PartitionInfo("topic-three", 1, null, null, null))));
 
-        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, 
Collections.emptySet(), 1);
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, 
Collections.emptySet(), Collections.singleton(1));
 
         final KeyQueryMetadata actual = 
metadataState.getKeyQueryMetadataForKey("table-three",
                 "the-key",
                 partitioner);
         assertEquals(expected, actual);
         assertEquals(1, actual.partition());
+        assertEquals(Collections.singleton(1), actual.partitions());
+    }
+
+    @Test
+    public void shouldGetInstanceWithKeyAndCustomMulticastingPartitioner() {
+        // Relies on the host/partition assignment already registered with metadataState; this test
+        // makes no onChange() call. MultiValuedPartitioner routes every key to partitions 0 and 1.
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostThree, Collections.singleton(hostTwo), mkSet(0, 1));
+
+        final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("table-three",
+                "the-key",
+                new MultiValuedPartitioner());
+        assertEquals(expected, actual);
+        // with more than one target partition, the single-partition accessor is expected to report -1
+        assertEquals(-1, actual.partition());
+        assertEquals(mkSet(0, 1), actual.partitions());
     }
 
     @Test
@@ -268,7 +304,7 @@ public class StreamsMetadataStateTest {
         metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions,
                 cluster.withPartitions(Collections.singletonMap(topic2P2, new 
PartitionInfo("topic-two", 2, null, null, null))));
 
-        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, 
mkSet(hostOne), 2);
+        final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, 
mkSet(hostOne), Collections.singleton(2));
 
         final KeyQueryMetadata actual = 
metadataState.getKeyQueryMetadataForKey("merged-table",  "the-key",
             (topic, key, value, numPartitions) -> 2);

Reply via email to