Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2020/05/07 10:56:51 UTC

[GitHub] [flink] AHeise commented on a change in pull request #11725: [FLINK-15670][API] Provide a Kafka Source/Sink pair as KafkaShuffle

AHeise commented on a change in pull request #11725:
URL: https://github.com/apache/flink/pull/11725#discussion_r421296506



##########
File path: flink-core/src/test/java/org/apache/flink/util/PropertiesUtilTest.java
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.util.PropertiesUtil.flatten;
+
+/**
+ * Tests for the {@link PropertiesUtil}.
+ */
+public class PropertiesUtilTest {
+
+	@Test
+	public void testFlatten() {
+		// default Properties is null
+		Properties prop1 = new Properties();
+		prop1.put("key1", "value1");
+
+		// default Properties is prop1
+		Properties prop2 = new Properties(prop1);
+		prop2.put("key2", "value2");
+
+		// default Properties is prop2
+		Properties prop3 = new Properties(prop2);
+		prop3.put("key3", "value3");
+
+		Properties flattened = flatten(prop3);
+		Assert.assertEquals(flattened.get("key1"), prop3.getProperty("key1"));
+		Assert.assertEquals(flattened.get("key2"), prop3.getProperty("key2"));
+		Assert.assertEquals(flattened.get("key3"), prop3.getProperty("key3"));
+		Assert.assertNotEquals(flattened.get("key1"), prop3.get("key1"));
+		Assert.assertNotEquals(flattened.get("key2"), prop3.get("key2"));
+		Assert.assertEquals(flattened.get("key3"), prop3.get("key3"));

Review comment:
       I'd remove these assertions and only keep the last 3.
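    Roughly what I have in mind (just a sketch, keeping only the three assertions against `prop3.get(...)`):
```java
Properties flattened = flatten(prop3);
// getProperty() resolves values through the default chain, plain get() does not:
// key1 and key2 only live in the defaults of prop3, key3 is a first-level entry.
Assert.assertNotEquals(flattened.get("key1"), prop3.get("key1"));
Assert.assertNotEquals(flattened.get("key2"), prop3.get("key2"));
Assert.assertEquals(flattened.get("key3"), prop3.get("key3"));
```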

##########
File path: flink-core/src/main/java/org/apache/flink/util/PropertiesUtil.java
##########
@@ -108,6 +109,27 @@ public static boolean getBoolean(Properties config, String key, boolean defaultV
 		}
 	}
 
+	/**
+	 * Flatten a recursive {@link Properties} to a first level property map.
+	 * In some cases ({KafkaProducer#propsToMap}, for example), Properties is used purely as a HashMap
+	 * without considering its default properties.
+	 *
+	 * @param config Properties to be flattened
+	 * @return Properties without defaults; all properties are put in the first level
+	 */
+	public static Properties flatten(Properties config) {

Review comment:
       > The flattened properties are actually used in the Kafka client lib, not that easy to fix.
   
   Does that mean that Kafka is actually not processing the recursive Properties correctly? We should probably file a bug report then.
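    For reference, this is roughly what I understand `flatten` to do (a minimal sketch, not necessarily the exact implementation in this PR): copy every key visible through the default chain into a single-level `Properties`, so that map-style access such as `propsToMap` sees all of them.
```java
public static Properties flatten(Properties config) {
	Properties flattened = new Properties();
	// stringPropertyNames() includes keys inherited from default Properties,
	// and getProperty() resolves values through the default chain.
	for (String key : config.stringPropertyNames()) {
		flattened.setProperty(key, config.getProperty(key));
	}
	return flattened;
}
```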

##########
File path: flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
##########
@@ -166,6 +167,13 @@ protected TXN currentTransaction() {
 	 */
 	protected abstract void invoke(TXN transaction, IN value, Context context) throws Exception;
 
+	/**
+	 * Handle watermark within a transaction.
+	 */
+	protected void invoke(TXN transaction, Watermark watermark) throws Exception {
+		throw new UnsupportedOperationException("invokeWithWatermark should not be invoked");
+	}
+

Review comment:
       Not necessary. `KafkaShuffleProducer` can use the protected `currentTransaction()`.
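    Something along these lines inside `FlinkKafkaShuffleProducer` instead of the new hooks (a sketch; the method name is just for illustration, and the broadcast logic stays as in the producer further down):
```java
// hypothetical entry point on the shuffle producer itself
public void writeWatermark(Watermark watermark) throws FlinkKafkaException {
	checkErroneous();
	// currentTransaction() is the protected accessor on TwoPhaseCommitSinkFunction
	KafkaTransactionState transaction = currentTransaction();
	// ... broadcast the watermark to all partitions via transaction.getProducer()
}
```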

##########
File path: flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
##########
@@ -235,6 +243,10 @@ public final void invoke(
 		invoke(currentTransactionHolder.handle, value, context);
 	}
 
+	public final void invoke(Watermark watermark) throws Exception {
+		invoke(currentTransactionHolder.handle, watermark);
+	}
+

Review comment:
       Not necessary. `KafkaShuffleProducer` can use the protected `currentTransaction()`.

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
##########
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import org.apache.kafka.clients.producer.ProducerRecord;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
+
+/**
+ * Flink Kafka Shuffle Producer Function.
+ * It is different from {@link FlinkKafkaProducer} in the way it handles elements and watermarks.
+ */
+@Internal
+public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
+	private final KafkaSerializer<IN> kafkaSerializer;
+	private final KeySelector<IN, KEY> keySelector;
+	private final int numberOfPartitions;
+
+	FlinkKafkaShuffleProducer(
+		String defaultTopicId,

Review comment:
       Nit: method parameters must be double-indented. Unfortunately, this cannot be done automatically with IntelliJ :(.
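    For illustration, the double-indented form I mean (same signature as above):
```java
FlinkKafkaShuffleProducer(
		String defaultTopicId,
		TypeInformationSerializationSchema<IN> schema,
		Properties props,
		KeySelector<IN, KEY> keySelector,
		Semantic semantic,
		int kafkaProducersPoolSize) {
	super(defaultTopicId, (element, timestamp) -> null, props, semantic, kafkaProducersPoolSize);
	// ...
}
```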

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
##########
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import org.apache.kafka.clients.producer.ProducerRecord;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
+
+/**
+ * Flink Kafka Shuffle Producer Function.
+ * It is different from {@link FlinkKafkaProducer} in the way it handles elements and watermarks.
+ */
+@Internal
+public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
+	private final KafkaSerializer<IN> kafkaSerializer;
+	private final KeySelector<IN, KEY> keySelector;
+	private final int numberOfPartitions;
+
+	FlinkKafkaShuffleProducer(
+		String defaultTopicId,
+		TypeInformationSerializationSchema<IN> schema,
+		Properties props,
+		KeySelector<IN, KEY> keySelector,
+		Semantic semantic,
+		int kafkaProducersPoolSize) {
+		super(defaultTopicId, (element, timestamp) -> null, props, semantic, kafkaProducersPoolSize);
+
+		this.kafkaSerializer = new KafkaSerializer<>(schema.getSerializer());
+		this.keySelector = keySelector;
+
+		Preconditions.checkArgument(
+			props.getProperty(PARTITION_NUMBER) != null,
+			"Missing partition number for Kafka Shuffle");
+		numberOfPartitions = PropertiesUtil.getInt(props, PARTITION_NUMBER, Integer.MIN_VALUE);
+	}
+
+	/**
+	 * This is the function invoked to handle each element.
+	 * @param transaction transaction state;
+	 *                    elements are written to Kafka in transactions to guarantee different levels of data consistency
+	 * @param next element to handle
+	 * @param context context needed to handle the element
+	 * @throws FlinkKafkaException for kafka error
+	 */
+	@Override
+	public void invoke(KafkaTransactionState transaction, IN next, Context context) throws FlinkKafkaException {
+		checkErroneous();
+
+		// write timestamp to Kafka if timestamp is available
+		Long timestamp = context.timestamp();
+
+		int[] partitions = getPartitions(transaction);
+		int partitionIndex;
+		try {
+			partitionIndex = KeyGroupRangeAssignment
+				.assignKeyToParallelOperator(keySelector.getKey(next), partitions.length, partitions.length);
+		} catch (Exception e) {
+			throw new RuntimeException("Failed to assign a partition number to the record", e);
+		}
+
+		ProducerRecord<byte[], byte[]> record = new ProducerRecord<>(
+			defaultTopicId, partitionIndex, timestamp, null, kafkaSerializer.serializeRecord(next, timestamp));
+		pendingRecords.incrementAndGet();
+		transaction.getProducer().send(record, callback);
+	}
+
+	/**
+	 * This is the function invoked to handle each watermark.
+	 * @param transaction transaction state;
+	 *                    watermarks are written to Kafka (if needed) in transactions
+	 * @param watermark watermark to handle
+	 * @throws FlinkKafkaException for kafka error
+	 */
+	@Override
+	public void invoke(KafkaTransactionState transaction, Watermark watermark) throws FlinkKafkaException {
+		checkErroneous();
+
+		int[] partitions = getPartitions(transaction);
+		int subtask = getRuntimeContext().getIndexOfThisSubtask();
+
+		// broadcast watermark
+		long timestamp = watermark.getTimestamp();
+		for (int partition : partitions) {
+			ProducerRecord<byte[], byte[]> record = new ProducerRecord<>(
+				defaultTopicId, partition, timestamp, null, kafkaSerializer.serializeWatermark(watermark, subtask));
+			pendingRecords.incrementAndGet();
+			transaction.getProducer().send(record, callback);
+		}
+	}
+
+	private int[] getPartitions(KafkaTransactionState transaction) {
+		int[] partitions = topicPartitionsMap.get(defaultTopicId);
+		if (partitions == null) {
+			partitions = getPartitionsByTopic(defaultTopicId, transaction.getProducer());
+			topicPartitionsMap.put(defaultTopicId, partitions);
+		}
+
+		Preconditions.checkArgument(partitions.length == numberOfPartitions);
+
+		return partitions;
+	}
+
+	/**
+	 * Flink Kafka Shuffle Serializer.
+	 */
+	public static final class KafkaSerializer<IN> implements Serializable {
+		public static final int TAG_REC_WITH_TIMESTAMP = 0;
+		public static final int TAG_REC_WITHOUT_TIMESTAMP = 1;
+		public static final int TAG_WATERMARK = 2;

Review comment:
       Is this supposed to be public or not? It should probably be package-private.
   
    I was also thinking of pulling it out as a top-level class, which could then also incorporate the deserialization logic from the next commit.
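    Roughly the shape I have in mind (just a sketch; the serialization methods stay as in this PR, and the deserializer of the next commit could live next to them):
```java
// package-private, top-level class in the shuffle package
final class KafkaSerializer<IN> implements Serializable {
	static final int TAG_REC_WITH_TIMESTAMP = 0;
	static final int TAG_REC_WITHOUT_TIMESTAMP = 1;
	static final int TAG_WATERMARK = 2;
	// ... serializeRecord / serializeWatermark as in this PR
}
```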

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+	private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+	private final WatermarkHandler watermarkHandler;
+	// ------------------------------------------------------------------------
+
+	/** The schema to convert between Kafka's byte messages, and Flink's objects. */
+	private final KafkaShuffleElementDeserializer<T> deserializer;
+
+	/** Serializer to serialize record. */
+	private final TypeSerializer<T> serializer;
+
+	/** The handover of data and exceptions between the consumer thread and the task thread. */
+	private final Handover handover;
+
+	/** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+	private final KafkaConsumerThread consumerThread;
+
+	/** Flag to mark the main work loop as alive. */
+	private volatile boolean running = true;
+
+	public KafkaShuffleFetcher(
+		SourceFunction.SourceContext<T> sourceContext,
+		Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+		SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+		SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+		ProcessingTimeService processingTimeProvider,
+		long autoWatermarkInterval,
+		ClassLoader userCodeClassLoader,
+		String taskNameWithSubtasks,
+		TypeSerializer<T> serializer,
+		Properties kafkaProperties,
+		long pollTimeout,
+		MetricGroup subtaskMetricGroup,
+		MetricGroup consumerMetricGroup,
+		boolean useMetrics,
+		int producerParallelism) throws Exception {
+		super(
+			sourceContext,
+			assignedPartitionsWithInitialOffsets,
+			watermarksPeriodic,
+			watermarksPunctuated,
+			processingTimeProvider,
+			autoWatermarkInterval,
+			userCodeClassLoader,
+			consumerMetricGroup,
+			useMetrics);
+
+		this.deserializer = new KafkaShuffleElementDeserializer<>();
+		this.serializer = serializer;
+		this.handover = new Handover();
+		this.consumerThread = new KafkaConsumerThread(
+			LOG,
+			handover,
+			kafkaProperties,
+			unassignedPartitionsQueue,
+			getFetcherName() + " for " + taskNameWithSubtasks,
+			pollTimeout,
+			useMetrics,
+			consumerMetricGroup,
+			subtaskMetricGroup);
+		this.watermarkHandler = new WatermarkHandler(producerParallelism);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Fetcher work methods
+	// ------------------------------------------------------------------------
+
+	@Override
+	public void runFetchLoop() throws Exception {
+		try {
+			final Handover handover = this.handover;
+
+			// kick off the actual Kafka consumer
+			consumerThread.start();
+
+			while (running) {
+				// this blocks until we get the next records
+				// it automatically re-throws exceptions encountered in the consumer thread
+				final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+				// get the records for each topic partition
+				for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+					List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+						records.records(partition.getKafkaPartitionHandle());
+
+					for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+						final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+						// TODO: do we need to check the end of stream if reaching the end watermark?

Review comment:
       I'd assume so, or else bounded inputs won't work well.
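    A rough sketch of the check I'd expect, assuming the producer eventually writes `Watermark.MAX_WATERMARK` for a bounded input:
```java
if (element.isWatermark()
		&& element.asWatermark().watermark == Watermark.MAX_WATERMARK.getTimestamp()) {
	// this producer subtask reached the end of its bounded input;
	// once all producer subtasks have, the partition can be treated as finished
}
```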

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+	private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+	private final WatermarkHandler watermarkHandler;
+	// ------------------------------------------------------------------------
+
+	/** The schema to convert between Kafka's byte messages, and Flink's objects. */
+	private final KafkaShuffleElementDeserializer<T> deserializer;
+
+	/** Serializer to serialize record. */
+	private final TypeSerializer<T> serializer;
+
+	/** The handover of data and exceptions between the consumer thread and the task thread. */
+	private final Handover handover;
+
+	/** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+	private final KafkaConsumerThread consumerThread;
+
+	/** Flag to mark the main work loop as alive. */
+	private volatile boolean running = true;
+
+	public KafkaShuffleFetcher(
+		SourceFunction.SourceContext<T> sourceContext,
+		Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+		SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+		SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+		ProcessingTimeService processingTimeProvider,
+		long autoWatermarkInterval,
+		ClassLoader userCodeClassLoader,
+		String taskNameWithSubtasks,
+		TypeSerializer<T> serializer,
+		Properties kafkaProperties,
+		long pollTimeout,
+		MetricGroup subtaskMetricGroup,
+		MetricGroup consumerMetricGroup,
+		boolean useMetrics,
+		int producerParallelism) throws Exception {
+		super(
+			sourceContext,
+			assignedPartitionsWithInitialOffsets,
+			watermarksPeriodic,
+			watermarksPunctuated,
+			processingTimeProvider,
+			autoWatermarkInterval,
+			userCodeClassLoader,
+			consumerMetricGroup,
+			useMetrics);
+
+		this.deserializer = new KafkaShuffleElementDeserializer<>();
+		this.serializer = serializer;
+		this.handover = new Handover();
+		this.consumerThread = new KafkaConsumerThread(
+			LOG,
+			handover,
+			kafkaProperties,
+			unassignedPartitionsQueue,
+			getFetcherName() + " for " + taskNameWithSubtasks,
+			pollTimeout,
+			useMetrics,
+			consumerMetricGroup,
+			subtaskMetricGroup);
+		this.watermarkHandler = new WatermarkHandler(producerParallelism);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Fetcher work methods
+	// ------------------------------------------------------------------------
+
+	@Override
+	public void runFetchLoop() throws Exception {
+		try {
+			final Handover handover = this.handover;
+
+			// kick off the actual Kafka consumer
+			consumerThread.start();
+
+			while (running) {
+				// this blocks until we get the next records
+				// it automatically re-throws exceptions encountered in the consumer thread
+				final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+				// get the records for each topic partition
+				for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+					List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+						records.records(partition.getKafkaPartitionHandle());
+
+					for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+						final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+						// TODO: do we need to check the end of stream if reaching the end watermark?
+
+						if (element.isRecord()) {
+							// timestamp is inherent from upstream
+							// If using ProcessTime, timestamp is going to be ignored (upstream does not include timestamp as well)
+							// If using IngestionTime, timestamp is going to be overwritten
+							// If using EventTime, timestamp is going to be used
+							synchronized (checkpointLock) {
+								KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+								sourceContext.collectWithTimestamp(
+									elementAsRecord.value,
+									elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+								partition.setOffset(record.offset());
+							}
+						} else if (element.isWatermark()) {
+							final KafkaShuffleWatermark watermark = element.asWatermark();
+							Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+							newWatermark.ifPresent(sourceContext::emitWatermark);
+						}
+					}
+				}
+			}
+		}
+		finally {
+			// this signals the consumer thread that no more work is to be done
+			consumerThread.shutdown();
+		}
+
+		// on a clean exit, wait for the runner thread
+		try {
+			consumerThread.join();
+		}
+		catch (InterruptedException e) {
+			// may be the result of a wake-up interruption after an exception.
+			// we ignore this here and only restore the interruption state
+			Thread.currentThread().interrupt();
+		}
+	}
+
+	@Override
+	public void cancel() {
+		// flag the main thread to exit. A thread interrupt will come anyways.
+		running = false;
+		handover.close();
+		consumerThread.shutdown();
+	}
+
+	@Override
+	protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition partition) {
+		return new TopicPartition(partition.getTopic(), partition.getPartition());
+	}
+
+	@Override
+	protected void doCommitInternalOffsetsToKafka(
+		Map<KafkaTopicPartition, Long> offsets,
+		@Nonnull KafkaCommitCallback commitCallback) throws Exception {
+
+		@SuppressWarnings("unchecked")
+		List<KafkaTopicPartitionState<TopicPartition>> partitions = subscribedPartitionStates();
+
+		Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>(partitions.size());
+
+		for (KafkaTopicPartitionState<TopicPartition> partition : partitions) {
+			Long lastProcessedOffset = offsets.get(partition.getKafkaTopicPartition());
+			if (lastProcessedOffset != null) {
+				checkState(lastProcessedOffset >= 0, "Illegal offset value to commit");
+
+				// committed offsets through the KafkaConsumer need to be 1 more than the last processed offset.
+				// This does not affect Flink's checkpoints/saved state.
+				long offsetToCommit = lastProcessedOffset + 1;
+
+				offsetsToCommit.put(partition.getKafkaPartitionHandle(), new OffsetAndMetadata(offsetToCommit));
+				partition.setCommittedOffset(offsetToCommit);
+			}
+		}
+
+		// record the work to be committed by the main consumer thread and make sure the consumer notices that
+		consumerThread.setOffsetsToCommit(offsetsToCommit, commitCallback);
+	}
+
+	private String getFetcherName() {
+		return "Kafka Shuffle Fetcher";
+	}
+
+	private abstract static class KafkaShuffleElement<T> {
+
+		public boolean isRecord() {
+			return getClass() == KafkaShuffleRecord.class;
+		}
+
+		boolean isWatermark() {
+			return getClass() == KafkaShuffleWatermark.class;
+		}
+
+		KafkaShuffleRecord<T> asRecord() {
+			return (KafkaShuffleRecord<T>) this;
+		}
+
+		KafkaShuffleWatermark asWatermark() {
+			return (KafkaShuffleWatermark) this;
+		}
+	}
+
+	private static class KafkaShuffleWatermark<T> extends KafkaShuffleElement<T> {
+		final int subtask;
+		final long watermark;
+
+		KafkaShuffleWatermark(int subtask, long watermark) {
+			this.subtask = subtask;
+			this.watermark = watermark;
+		}
+	}
+
+	private static class KafkaShuffleRecord<T> extends KafkaShuffleElement<T> {
+		final T value;
+		final Long timestamp;
+
+		KafkaShuffleRecord(T value) {
+			this.value = value;
+			this.timestamp = null;
+		}
+
+		KafkaShuffleRecord(long timestamp, T value) {
+			this.value = value;
+			this.timestamp = timestamp;
+		}
+	}
+

Review comment:
       `KafkaShuffleElement` seems over-engineered. I guess a holder for timestamp + object is enough, combined with simple `instanceof` checks.
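    E.g., something like this in the fetch loop (just a sketch of the simpler shape):
```java
final Object element = deserializer.deserialize(serializer, record);
if (element instanceof KafkaShuffleRecord) {
	@SuppressWarnings("unchecked")
	KafkaShuffleRecord<T> rec = (KafkaShuffleRecord<T>) element;
	// emit rec.value with rec.timestamp (falling back to the Kafka record timestamp)
} else if (element instanceof KafkaShuffleWatermark) {
	// hand the deserialized watermark to the watermark handler
}
```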

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+	private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+	private final WatermarkHandler watermarkHandler;
+	// ------------------------------------------------------------------------
+
+	/** The schema to convert between Kafka's byte messages, and Flink's objects. */
+	private final KafkaShuffleElementDeserializer<T> deserializer;
+
+	/** Serializer to serialize record. */
+	private final TypeSerializer<T> serializer;
+
+	/** The handover of data and exceptions between the consumer thread and the task thread. */
+	private final Handover handover;
+
+	/** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+	private final KafkaConsumerThread consumerThread;
+
+	/** Flag to mark the main work loop as alive. */
+	private volatile boolean running = true;
+
+	public KafkaShuffleFetcher(
+		SourceFunction.SourceContext<T> sourceContext,

Review comment:
       nit: indent.

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffleProducer.java
##########
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaException;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import org.apache.kafka.clients.producer.ProducerRecord;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffle.PARTITION_NUMBER;
+
+/**
+ * Flink Kafka Shuffle Producer Function.
+ * It is different from {@link FlinkKafkaProducer} in the way it handles elements and watermarks.
+ */
+@Internal
+public class FlinkKafkaShuffleProducer<IN, KEY> extends FlinkKafkaProducer<IN> {
+	private final KafkaSerializer<IN> kafkaSerializer;
+	private final KeySelector<IN, KEY> keySelector;
+	private final int numberOfPartitions;
+
+	FlinkKafkaShuffleProducer(
+		String defaultTopicId,
+		TypeInformationSerializationSchema<IN> schema,
+		Properties props,
+		KeySelector<IN, KEY> keySelector,
+		Semantic semantic,
+		int kafkaProducersPoolSize) {
+		super(defaultTopicId, (element, timestamp) -> null, props, semantic, kafkaProducersPoolSize);
+
+		this.kafkaSerializer = new KafkaSerializer<>(schema.getSerializer());
+		this.keySelector = keySelector;
+
+		Preconditions.checkArgument(
+			props.getProperty(PARTITION_NUMBER) != null,
+			"Missing partition number for Kafka Shuffle");
+		numberOfPartitions = PropertiesUtil.getInt(props, PARTITION_NUMBER, Integer.MIN_VALUE);
+	}
+
+	/**
+	 * This is the function invoked to handle each element.
+	 * @param transaction transaction state;
+	 *                    elements are written to Kafka in transactions to guarantee different levels of data consistency
+	 * @param next element to handle
+	 * @param context context needed to handle the element
+	 * @throws FlinkKafkaException for kafka error
+	 */
+	@Override
+	public void invoke(KafkaTransactionState transaction, IN next, Context context) throws FlinkKafkaException {
+		checkErroneous();
+
+		// write timestamp to Kafka if timestamp is available
+		Long timestamp = context.timestamp();
+
+		int[] partitions = getPartitions(transaction);
+		int partitionIndex;
+		try {
+			partitionIndex = KeyGroupRangeAssignment
+				.assignKeyToParallelOperator(keySelector.getKey(next), partitions.length, partitions.length);
+		} catch (Exception e) {
+			throw new RuntimeException("Failed to assign a partition number to the record", e);
+		}
+
+		ProducerRecord<byte[], byte[]> record = new ProducerRecord<>(

Review comment:
       Can you explain to me once again why we store the timestamp directly in the `ProducerRecord` and still also serialize it? It seems redundant.

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {
+	static final String PRODUCER_PARALLELISM = "producer parallelism";
+	static final String PARTITION_NUMBER = "partition number";
+
+	/**
+	 * Write to and read from a kafka shuffle with the partition decided by keys.
+	 * Consumers should read partitions equal to the key group indices they are assigned.
+	 * The number of partitions is the maximum parallelism of the receiving operator.
+	 * This version only supports numberOfPartitions = consumerParallelism.
+	 *
+	 * @param inputStream input stream to the kafka
+	 * @param topic kafka topic
+	 * @param producerParallelism parallelism of producer
+	 * @param numberOfPartitions number of partitions
+	 * @param properties Kafka properties
+	 * @param fields key positions from inputStream
+	 * @param <T> input type
+	 */
+	static <T> KeyedStream<T, Tuple> persistentKeyBy(

Review comment:
       `public`

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+	private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+	private final WatermarkHandler watermarkHandler;
+	// ------------------------------------------------------------------------
+
+	/** The schema to convert between Kafka's byte messages, and Flink's objects. */
+	private final KafkaShuffleElementDeserializer<T> deserializer;
+
+	/** Serializer to serialize record. */
+	private final TypeSerializer<T> serializer;
+
+	/** The handover of data and exceptions between the consumer thread and the task thread. */
+	private final Handover handover;
+
+	/** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+	private final KafkaConsumerThread consumerThread;
+
+	/** Flag to mark the main work loop as alive. */
+	private volatile boolean running = true;
+
+	public KafkaShuffleFetcher(
+		SourceFunction.SourceContext<T> sourceContext,
+		Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+		SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+		SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+		ProcessingTimeService processingTimeProvider,
+		long autoWatermarkInterval,
+		ClassLoader userCodeClassLoader,
+		String taskNameWithSubtasks,
+		TypeSerializer<T> serializer,
+		Properties kafkaProperties,
+		long pollTimeout,
+		MetricGroup subtaskMetricGroup,
+		MetricGroup consumerMetricGroup,
+		boolean useMetrics,
+		int producerParallelism) throws Exception {
+		super(
+			sourceContext,
+			assignedPartitionsWithInitialOffsets,
+			watermarksPeriodic,
+			watermarksPunctuated,
+			processingTimeProvider,
+			autoWatermarkInterval,
+			userCodeClassLoader,
+			consumerMetricGroup,
+			useMetrics);
+
+		this.deserializer = new KafkaShuffleElementDeserializer<>();
+		this.serializer = serializer;
+		this.handover = new Handover();
+		this.consumerThread = new KafkaConsumerThread(
+			LOG,
+			handover,
+			kafkaProperties,
+			unassignedPartitionsQueue,
+			getFetcherName() + " for " + taskNameWithSubtasks,
+			pollTimeout,
+			useMetrics,
+			consumerMetricGroup,
+			subtaskMetricGroup);
+		this.watermarkHandler = new WatermarkHandler(producerParallelism);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Fetcher work methods
+	// ------------------------------------------------------------------------
+
+	@Override
+	public void runFetchLoop() throws Exception {
+		try {
+			final Handover handover = this.handover;
+
+			// kick off the actual Kafka consumer
+			consumerThread.start();
+
+			while (running) {
+				// this blocks until we get the next records
+				// it automatically re-throws exceptions encountered in the consumer thread
+				final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+				// get the records for each topic partition
+				for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+					List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+						records.records(partition.getKafkaPartitionHandle());
+
+					for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+						final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+						// TODO: do we need to check the end of stream if reaching the end watermark?
+
+						if (element.isRecord()) {
+							// timestamp is inherent from upstream
+							// If using ProcessTime, timestamp is going to be ignored (upstream does not include timestamp as well)
+							// If using IngestionTime, timestamp is going to be overwritten
+							// If using EventTime, timestamp is going to be used
+							synchronized (checkpointLock) {
+								KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+								sourceContext.collectWithTimestamp(
+									elementAsRecord.value,
+									elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+								partition.setOffset(record.offset());
+							}
+						} else if (element.isWatermark()) {
+							final KafkaShuffleWatermark watermark = element.asWatermark();
+							Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+							newWatermark.ifPresent(sourceContext::emitWatermark);

Review comment:
       Perform under checkpoint lock?
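    I.e., roughly:
```java
} else if (element.isWatermark()) {
	final KafkaShuffleWatermark watermark = element.asWatermark();
	Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
	synchronized (checkpointLock) {
		newWatermark.ifPresent(sourceContext::emitWatermark);
	}
}
```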

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {

Review comment:
       If this is part of the public API, I guess it should be `public` and annotated with either `@PublicEvolving` or `@Experimental`.
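    E.g. (which of the two annotations to use is up for discussion):
```java
import org.apache.flink.annotation.Experimental;

@Experimental
public class FlinkKafkaShuffle {
	// ...
}
```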

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+	private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+	private final WatermarkHandler watermarkHandler;
+	// ------------------------------------------------------------------------
+
+	/** The schema to convert between Kafka's byte messages, and Flink's objects. */
+	private final KafkaShuffleElementDeserializer<T> deserializer;
+
+	/** Serializer to serialize record. */
+	private final TypeSerializer<T> serializer;
+
+	/** The handover of data and exceptions between the consumer thread and the task thread. */
+	private final Handover handover;
+
+	/** The thread that runs the actual KafkaConsumer and hand the record batches to this fetcher. */
+	private final KafkaConsumerThread consumerThread;
+
+	/** Flag to mark the main work loop as alive. */
+	private volatile boolean running = true;
+
+	public KafkaShuffleFetcher(
+		SourceFunction.SourceContext<T> sourceContext,
+		Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+		SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+		SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+		ProcessingTimeService processingTimeProvider,
+		long autoWatermarkInterval,
+		ClassLoader userCodeClassLoader,
+		String taskNameWithSubtasks,
+		TypeSerializer<T> serializer,
+		Properties kafkaProperties,
+		long pollTimeout,
+		MetricGroup subtaskMetricGroup,
+		MetricGroup consumerMetricGroup,
+		boolean useMetrics,
+		int producerParallelism) throws Exception {
+		super(
+			sourceContext,
+			assignedPartitionsWithInitialOffsets,
+			watermarksPeriodic,
+			watermarksPunctuated,
+			processingTimeProvider,
+			autoWatermarkInterval,
+			userCodeClassLoader,
+			consumerMetricGroup,
+			useMetrics);
+
+		this.deserializer = new KafkaShuffleElementDeserializer<>();
+		this.serializer = serializer;
+		this.handover = new Handover();
+		this.consumerThread = new KafkaConsumerThread(
+			LOG,
+			handover,
+			kafkaProperties,
+			unassignedPartitionsQueue,
+			getFetcherName() + " for " + taskNameWithSubtasks,
+			pollTimeout,
+			useMetrics,
+			consumerMetricGroup,
+			subtaskMetricGroup);
+		this.watermarkHandler = new WatermarkHandler(producerParallelism);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Fetcher work methods
+	// ------------------------------------------------------------------------
+
+	@Override
+	public void runFetchLoop() throws Exception {
+		try {
+			final Handover handover = this.handover;
+
+			// kick off the actual Kafka consumer
+			consumerThread.start();
+
+			while (running) {
+				// this blocks until we get the next records
+				// it automatically re-throws exceptions encountered in the consumer thread
+				final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+				// get the records for each topic partition
+				for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+					List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+						records.records(partition.getKafkaPartitionHandle());
+
+					for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+						final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+						// TODO: do we need to check the end of stream if reaching the end watermark?
+
+						if (element.isRecord()) {
+							// The timestamp is inherited from upstream.
+							// If using ProcessingTime, the timestamp is ignored (upstream does not include a timestamp either)
+							// If using IngestionTime, the timestamp is overwritten
+							// If using EventTime, the timestamp is used
+							synchronized (checkpointLock) {
+								KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+								sourceContext.collectWithTimestamp(
+									elementAsRecord.value,
+									elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+								partition.setOffset(record.offset());
+							}
+						} else if (element.isWatermark()) {
+							final KafkaShuffleWatermark watermark = element.asWatermark();
+							Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+							newWatermark.ifPresent(sourceContext::emitWatermark);
+						}
+					}
+				}
+			}
+		}
+		finally {
+			// this signals the consumer thread that no more work is to be done
+			consumerThread.shutdown();
+		}
+
+		// on a clean exit, wait for the runner thread
+		try {
+			consumerThread.join();
+		}
+		catch (InterruptedException e) {
+			// may be the result of a wake-up interruption after an exception.
+			// we ignore this here and only restore the interruption state
+			Thread.currentThread().interrupt();
+		}
+	}
+
+	@Override
+	public void cancel() {
+		// flag the main thread to exit. A thread interrupt will come anyways.
+		running = false;
+		handover.close();
+		consumerThread.shutdown();
+	}
+
+	@Override
+	protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition partition) {
+		return new TopicPartition(partition.getTopic(), partition.getPartition());
+	}
+
+	@Override
+	protected void doCommitInternalOffsetsToKafka(
+		Map<KafkaTopicPartition, Long> offsets,
+		@Nonnull KafkaCommitCallback commitCallback) throws Exception {
+
+		@SuppressWarnings("unchecked")
+		List<KafkaTopicPartitionState<TopicPartition>> partitions = subscribedPartitionStates();
+
+		Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>(partitions.size());
+
+		for (KafkaTopicPartitionState<TopicPartition> partition : partitions) {
+			Long lastProcessedOffset = offsets.get(partition.getKafkaTopicPartition());
+			if (lastProcessedOffset != null) {
+				checkState(lastProcessedOffset >= 0, "Illegal offset value to commit");
+
+				// committed offsets through the KafkaConsumer need to be 1 more than the last processed offset.
+				// This does not affect Flink's checkpoints/saved state.
+				long offsetToCommit = lastProcessedOffset + 1;
+
+				offsetsToCommit.put(partition.getKafkaPartitionHandle(), new OffsetAndMetadata(offsetToCommit));
+				partition.setCommittedOffset(offsetToCommit);
+			}
+		}
+
+		// record the work to be committed by the main consumer thread and make sure the consumer notices that
+		consumerThread.setOffsetsToCommit(offsetsToCommit, commitCallback);
+	}
+
+	private String getFetcherName() {
+		return "Kafka Shuffle Fetcher";
+	}
+
+	private abstract static class KafkaShuffleElement<T> {
+
+		public boolean isRecord() {
+			return getClass() == KafkaShuffleRecord.class;
+		}
+
+		boolean isWatermark() {
+			return getClass() == KafkaShuffleWatermark.class;
+		}
+
+		KafkaShuffleRecord<T> asRecord() {
+			return (KafkaShuffleRecord<T>) this;
+		}
+
+		KafkaShuffleWatermark asWatermark() {
+			return (KafkaShuffleWatermark) this;
+		}
+	}
+
+	private static class KafkaShuffleWatermark<T> extends KafkaShuffleElement<T> {
+		final int subtask;
+		final long watermark;
+
+		KafkaShuffleWatermark(int subtask, long watermark) {
+			this.subtask = subtask;
+			this.watermark = watermark;
+		}
+	}
+
+	private static class KafkaShuffleRecord<T> extends KafkaShuffleElement<T> {
+		final T value;
+		final Long timestamp;
+
+		KafkaShuffleRecord(T value) {
+			this.value = value;
+			this.timestamp = null;
+		}
+
+		KafkaShuffleRecord(long timestamp, T value) {
+			this.value = value;
+			this.timestamp = timestamp;
+		}
+	}
+
+	private static class KafkaShuffleElementDeserializer<T> implements Serializable {
+		private transient DataInputDeserializer dis;
+
+		KafkaShuffleElementDeserializer() {
+			this.dis = new DataInputDeserializer();
+		}
+
+		KafkaShuffleElement<T> deserialize(TypeSerializer<T> serializer, ConsumerRecord<byte[], byte[]> record)
+			throws Exception {
+			byte[] value = record.value();
+			dis.setBuffer(value);
+			int tag = IntSerializer.INSTANCE.deserialize(dis);
+
+			if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
+				return new KafkaShuffleRecord<>(serializer.deserialize(dis));
+			} else if (tag == TAG_REC_WITH_TIMESTAMP) {
+				return new KafkaShuffleRecord<>(LongSerializer.INSTANCE.deserialize(dis), serializer.deserialize(dis));

Review comment:
       Again, why do we serialize timestamp in the payload and not take it from `ConsumerRecord`?
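
    For illustration, a rough sketch of the alternative, i.e. carrying the timestamp in the Kafka record itself; the helper class and method names below are made up for this comment and are not part of the PR:
    ```java
    import org.apache.kafka.clients.consumer.ConsumerRecord;
    import org.apache.kafka.clients.producer.ProducerRecord;

    // Hypothetical helper, only to illustrate the idea: carry the element timestamp
    // in the Kafka record timestamp field instead of inside the serialized value.
    final class KafkaTimestampSketch {

    	// Producer side: attach the Flink element timestamp to the ProducerRecord.
    	static ProducerRecord<byte[], byte[]> withTimestamp(
    			String topic, int partition, Long elementTimestamp, byte[] value) {
    		// A null timestamp makes the producer fill in its own wall-clock time.
    		return new ProducerRecord<>(topic, partition, elementTimestamp, null, value);
    	}

    	// Consumer side: the fetcher would then read the timestamp back directly.
    	static long timestampOf(ConsumerRecord<byte[], byte[]> record) {
    		return record.timestamp();
    	}
    }
    ```
    One caveat: with a null timestamp the producer fills in its own wall-clock time, so the consumer could no longer tell "no timestamp" apart from a real one, which may be exactly why the PR encodes it in the payload.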

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {
+	static final String PRODUCER_PARALLELISM = "producer parallelism";
+	static final String PARTITION_NUMBER = "partition number";
+
+	/**
+	 * Write to and read from a Kafka shuffle with the partition decided by keys.
+	 * Consumers should read partitions equal to the key group indices they are assigned.
+	 * The number of partitions is the maximum parallelism of the receiving operator.
+	 * This version only supports numberOfPartitions = consumerParallelism.
+	 *
+	 * @param inputStream input stream to Kafka
+	 * @param topic Kafka topic
+	 * @param producerParallelism parallelism of producer
+	 * @param numberOfPartitions number of partitions

Review comment:
       Shouldn't that be the same?

##########
File path: flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
##########
@@ -264,7 +264,7 @@ public FlinkKafkaConsumerBase(
 	 * @param properties - Kafka configuration properties to be adjusted
 	 * @param offsetCommitMode offset commit mode
 	 */
-	static void adjustAutoCommitConfig(Properties properties, OffsetCommitMode offsetCommitMode) {
+	public static void adjustAutoCommitConfig(Properties properties, OffsetCommitMode offsetCommitMode) {

Review comment:
       protected?

##########
File path: flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple end-to-end tests for the Kafka shuffle.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+	@BeforeClass
+	public static void prepare() throws Exception {
+		KafkaProducerTestBase.prepare();
+		((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleProcessingTime() throws Exception {
+		simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleIngestionTime() throws Exception {
+		simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleEventTime() throws Exception {
+		simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionProcessingTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionIngestionTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionEventTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Schema: (key, timestamp, source instance Id).
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+	 * To test no data is lost or duplicated end-2-end
+	 */
+	private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+		throws Exception {
+		final int numberOfPartitions = 1;
+		final int producerParallelism = 1;
+
+		createTestTopic(topic, numberOfPartitions, 1);
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(producerParallelism);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+		env.setStreamTimeCharacteristic(timeCharacteristic);
+
+		DataStream<Tuple3<Integer, Long, String>> source =
+			env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+		DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+			source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+		Properties properties = kafkaServer.getStandardProperties();
+		FlinkKafkaShuffle
+			.persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0)
+			.map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+			.map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+		tryExecute(env, topic);
+
+		deleteTestTopic(topic);
+	}
+
+	/**
+	 * Schema: (key, timestamp, source instance Id).
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3
+	 * To test data is partitioned to the right partition
+	 */
+	private void testAssignedToPartition(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+		throws Exception {
+		final int numberOfPartitions = 3;
+		final int producerParallelism = 2;
+
+		createTestTopic(topic, numberOfPartitions, 1);
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(producerParallelism);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+		env.setStreamTimeCharacteristic(EventTime);
+
+		DataStream<Tuple3<Integer, Long, String>> source =
+			env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+		DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+			source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+		// ------- Write data to Kafka partition based on FlinkKafkaPartitioner ------
+		Properties properties = kafkaServer.getStandardProperties();
+
+		KeyedStream<Tuple3<Integer, Long, String>, Tuple> keyedStream = FlinkKafkaShuffle
+			.persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0);
+		keyedStream
+			.process(new PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+			.setParallelism(numberOfPartitions)
+			.map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+			.map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+		tryExecute(env, "KafkaShuffle partition assignment test");
+
+		deleteTestTopic(topic);

Review comment:
       Extract to avoid duplicate code with `simpleEndToEndTest`.
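
    A possible shape for the shared part (the helper name and signature are only illustrative; it reuses the test's own `KafkaSourceFunction` and `PunctuatedExtractor`):
    ```java
    	private DataStream<Tuple3<Integer, Long, String>> createSourceStream(
    			StreamExecutionEnvironment env,
    			int elementCount,
    			int producerParallelism,
    			TimeCharacteristic timeCharacteristic) {
    		env.setParallelism(producerParallelism);
    		env.setRestartStrategy(RestartStrategies.noRestart());
    		env.setStreamTimeCharacteristic(timeCharacteristic);

    		DataStream<Tuple3<Integer, Long, String>> source =
    			env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);

    		// Only event time needs explicit timestamps/watermarks from the test source.
    		return timeCharacteristic == EventTime
    			? source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism)
    			: source;
    	}
    ```
    Both tests would then only differ in the topic/partition constants and the validation chain at the end.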

##########
File path: flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple end-to-end tests for the Kafka shuffle.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+	@BeforeClass
+	public static void prepare() throws Exception {
+		KafkaProducerTestBase.prepare();
+		((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)

Review comment:
       Instead of setting the timeout on all test methods, I'd go with a JUnit rule:
   ```
   	@Rule
   	public final Timeout timeout = Timeout.builder()
   			.withTimeout(30, TimeUnit.SECONDS)
   			.build();
   ```
   
    and then only use `@Test` on the tests. That's easier to maintain when we need to increase the timeout on Azure.

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/shuffle/FlinkKafkaShuffle.java
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.operators.Keys;
+import org.apache.flink.api.common.serialization.TypeInformationSerializationSchema;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.DataStreamUtils;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.transformations.SinkTransformation;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.util.keys.KeySelectorUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.PropertiesUtil;
+
+import java.util.Properties;
+
+/**
+ * Use Kafka as a persistent shuffle by wrapping a Kafka Source/Sink pair together.
+ */
+class FlinkKafkaShuffle {
+	static final String PRODUCER_PARALLELISM = "producer parallelism";
+	static final String PARTITION_NUMBER = "partition number";
+
+	/**
+	 * Write to and read from a Kafka shuffle with the partition decided by keys.
+	 * Consumers should read partitions equal to the key group indices they are assigned.
+	 * The number of partitions is the maximum parallelism of the receiving operator.
+	 * This version only supports numberOfPartitions = consumerParallelism.
+	 *
+	 * @param inputStream input stream to Kafka
+	 * @param topic Kafka topic
+	 * @param producerParallelism parallelism of producer
+	 * @param numberOfPartitions number of partitions
+	 * @param properties Kafka properties
+	 * @param fields key positions from inputStream
+	 * @param <T> input type
+	 */
+	static <T> KeyedStream<T, Tuple> persistentKeyBy(
+		DataStream<T> inputStream,

Review comment:
       nit: indent.

##########
File path: flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple end-to-end tests for the Kafka shuffle.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {

Review comment:
       The implemented tests are really good. I miss two cases though:
    * Out of order events (add randomness to the source timestamps; a sketch of such a source follows below)
   * Any failure and recovery tests. See https://github.com/apache/flink/blob/f239d680e9b8f3f5ace621b7806e0bb7e14d3fdd/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java for a possible approach.
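
    For the first point, a rough sketch of an out-of-order variant of the test source's `run` method (the seed and jitter bound are arbitrary; it needs an extra `java.util.Random` import):
    ```java
    	@Override
    	public void run(SourceContext<Tuple3<Integer, Long, String>> ctx) {
    		final long baseTimestamp = 1584349939799L;
    		final Random random = new Random(42); // fixed seed keeps the test reproducible
    		int instanceId = getRuntimeContext().getIndexOfThisSubtask();
    		for (int i = 0; i < elementCount && running; i++) {
    			// Jitter the timestamp so consecutive elements can arrive out of order.
    			long timestamp = baseTimestamp + i + random.nextInt(10);
    			ctx.collect(new Tuple3<>(i, timestamp, "source-instance-" + instanceId));
    		}
    	}
    ```
    The `PunctuatedExtractor` would then have to emit watermarks that lag behind the maximum seen timestamp, otherwise the jittered elements become late.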

##########
File path: flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple end-to-end tests for the Kafka shuffle.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+	@BeforeClass
+	public static void prepare() throws Exception {
+		KafkaProducerTestBase.prepare();
+		((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleProcessingTime() throws Exception {
+		simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleIngestionTime() throws Exception {
+		simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleEventTime() throws Exception {
+		simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionProcessingTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionIngestionTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionEventTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Schema: (key, timestamp, source instance Id).
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+	 * To test no data is lost or duplicated end-2-end
+	 */
+	private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+		throws Exception {
+		final int numberOfPartitions = 1;
+		final int producerParallelism = 1;
+
+		createTestTopic(topic, numberOfPartitions, 1);
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(producerParallelism);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+		env.setStreamTimeCharacteristic(timeCharacteristic);
+
+		DataStream<Tuple3<Integer, Long, String>> source =
+			env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+		DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+			source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+		Properties properties = kafkaServer.getStandardProperties();
+		FlinkKafkaShuffle
+			.persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0)
+			.map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+			.map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+		tryExecute(env, topic);
+
+		deleteTestTopic(topic);
+	}
+
+	/**
+	 * Schema: (key, timestamp, source instance Id).
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3
+	 * To test data is partitioned to the right partition
+	 */
+	private void testAssignedToPartition(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+		throws Exception {
+		final int numberOfPartitions = 3;
+		final int producerParallelism = 2;
+
+		createTestTopic(topic, numberOfPartitions, 1);
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(producerParallelism);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+		env.setStreamTimeCharacteristic(EventTime);
+
+		DataStream<Tuple3<Integer, Long, String>> source =
+			env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+		DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+			source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+		// ------- Write data to Kafka partition based on FlinkKafkaPartitioner ------
+		Properties properties = kafkaServer.getStandardProperties();
+
+		KeyedStream<Tuple3<Integer, Long, String>, Tuple> keyedStream = FlinkKafkaShuffle
+			.persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0);
+		keyedStream
+			.process(new PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+			.setParallelism(numberOfPartitions)
+			.map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+			.map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+		tryExecute(env, "KafkaShuffle partition assignment test");
+
+		deleteTestTopic(topic);
+	}
+
+	private static class PunctuatedExtractor implements AssignerWithPunctuatedWatermarks<Tuple3<Integer, Long, String>> {
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public long extractTimestamp(Tuple3<Integer, Long, String> element, long previousTimestamp) {
+			return element.f1;
+		}
+
+		@Override
+		public Watermark checkAndGetNextWatermark(Tuple3<Integer, Long, String> lastElement, long extractedTimestamp) {
+			return new Watermark(extractedTimestamp);
+		}
+	}
+
+	private static class KafkaSourceFunction extends RichParallelSourceFunction<Tuple3<Integer, Long, String>> {
+		private volatile boolean running = true;
+		private int elementCount;
+
+		KafkaSourceFunction(int elementCount) {
+			this.elementCount = elementCount;
+		}
+
+		@Override
+		public void run(SourceContext<Tuple3<Integer, Long, String>> ctx) {
+			long timestamp = 1584349939799L;
+			int instanceId = getRuntimeContext().getIndexOfThisSubtask();
+			for (int i = 0; i < elementCount && running; i++) {
+				ctx.collect(new Tuple3<>(i, timestamp++, "source-instance-" + instanceId));
+			}
+		}
+
+		@Override
+		public void cancel() {
+			running = false;
+		}
+	}
+
+	private static class ElementCountNoMoreThanValidator
+		implements MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {

Review comment:
       nit: also double-indent.

##########
File path: flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple end-to-end tests for the Kafka shuffle.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+	@BeforeClass
+	public static void prepare() throws Exception {
+		KafkaProducerTestBase.prepare();
+		((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleProcessingTime() throws Exception {
+		simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleIngestionTime() throws Exception {
+		simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleEventTime() throws Exception {
+		simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionProcessingTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionIngestionTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionEventTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Schema: (key, timestamp, source instance Id).
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+	 * To test no data is lost or duplicated end-2-end
+	 */
+	private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)

Review comment:
       We use end2end in a different context, where we use a complete Flink distribution to execute the test. 
   
   I'd simply call it `testKafkaShuffle` to avoid any misunderstanding.

##########
File path: flink-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/shuffle/KafkaShuffleITCase.java
##########
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.shuffle;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducer;
+import org.apache.flink.streaming.connectors.kafka.KafkaConsumerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaProducerTestBase;
+import org.apache.flink.streaming.connectors.kafka.KafkaTestEnvironmentImpl;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionAssigner;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.apache.flink.streaming.api.TimeCharacteristic.EventTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.IngestionTime;
+import static org.apache.flink.streaming.api.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.test.util.TestUtils.tryExecute;
+
+/**
+ * Simple end-to-end tests for the Kafka shuffle.
+ */
+public class KafkaShuffleITCase extends KafkaConsumerTestBase {
+
+	@BeforeClass
+	public static void prepare() throws Exception {
+		KafkaProducerTestBase.prepare();
+		((KafkaTestEnvironmentImpl) kafkaServer).setProducerSemantic(FlinkKafkaProducer.Semantic.AT_LEAST_ONCE);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with the default time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleProcessingTime() throws Exception {
+		simpleEndToEndTest("test_simple_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleIngestionTime() throws Exception {
+		simpleEndToEndTest("test_simple_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1.
+	 * To test no data is lost or duplicated end-2-end with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testSimpleEventTime() throws Exception {
+		simpleEndToEndTest("test_simple_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: ProcessingTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionProcessingTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_processing_time", 100000, ProcessingTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: IngestionTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionIngestionTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_ingestion_time", 100000, IngestionTime);
+	}
+
+	/**
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3.
+	 * To test data is partitioned to the right partition with time characteristic: EventTime
+	 */
+	@Test(timeout = 30000L)
+	public void testAssignedToPartitionEventTime() throws Exception {
+		testAssignedToPartition("test_assigned_to_partition_event_time", 100000, EventTime);
+	}
+
+	/**
+	 * Schema: (key, timestamp, source instance Id).
+	 * Producer Parallelism = 1; Kafka Partition # = 1; Consumer Parallelism = 1
+	 * To test no data is lost or duplicated end-2-end
+	 */
+	private void simpleEndToEndTest(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+		throws Exception {
+		final int numberOfPartitions = 1;
+		final int producerParallelism = 1;
+
+		createTestTopic(topic, numberOfPartitions, 1);
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(producerParallelism);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+		env.setStreamTimeCharacteristic(timeCharacteristic);
+
+		DataStream<Tuple3<Integer, Long, String>> source =
+			env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+		DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+			source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+		Properties properties = kafkaServer.getStandardProperties();
+		FlinkKafkaShuffle
+			.persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0)
+			.map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+			.map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+		tryExecute(env, topic);
+
+		deleteTestTopic(topic);
+	}
+
+	/**
+	 * Schema: (key, timestamp, source instance Id).
+	 * Producer Parallelism = 2; Kafka Partition # = 3; Consumer Parallelism = 3
+	 * To test data is partitioned to the right partition
+	 */
+	private void testAssignedToPartition(String topic, int elementCount, TimeCharacteristic timeCharacteristic)
+		throws Exception {
+		final int numberOfPartitions = 3;
+		final int producerParallelism = 2;
+
+		createTestTopic(topic, numberOfPartitions, 1);
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(producerParallelism);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+		env.setStreamTimeCharacteristic(EventTime);
+
+		DataStream<Tuple3<Integer, Long, String>> source =
+			env.addSource(new KafkaSourceFunction(elementCount)).setParallelism(producerParallelism);
+
+		DataStream<Tuple3<Integer, Long, String>> input = (timeCharacteristic == EventTime) ?
+			source.assignTimestampsAndWatermarks(new PunctuatedExtractor()).setParallelism(producerParallelism) : source;
+
+		// ------- Write data to Kafka partition based on FlinkKafkaPartitioner ------
+		Properties properties = kafkaServer.getStandardProperties();
+
+		KeyedStream<Tuple3<Integer, Long, String>, Tuple> keyedStream = FlinkKafkaShuffle
+			.persistentKeyBy(input, topic, producerParallelism, numberOfPartitions, properties, 0);
+		keyedStream
+			.process(new PartitionValidator(keyedStream.getKeySelector(), numberOfPartitions, topic))
+			.setParallelism(numberOfPartitions)
+			.map(new ElementCountNoMoreThanValidator(elementCount * producerParallelism)).setParallelism(1)
+			.map(new ElementCountNoLessThanValidator(elementCount * producerParallelism)).setParallelism(1);
+
+		tryExecute(env, "KafkaShuffle partition assignment test");
+
+		deleteTestTopic(topic);
+	}
+
+	private static class PunctuatedExtractor implements AssignerWithPunctuatedWatermarks<Tuple3<Integer, Long, String>> {
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public long extractTimestamp(Tuple3<Integer, Long, String> element, long previousTimestamp) {
+			return element.f1;
+		}
+
+		@Override
+		public Watermark checkAndGetNextWatermark(Tuple3<Integer, Long, String> lastElement, long extractedTimestamp) {
+			return new Watermark(extractedTimestamp);
+		}
+	}
+
+	private static class KafkaSourceFunction extends RichParallelSourceFunction<Tuple3<Integer, Long, String>> {
+		private volatile boolean running = true;
+		private int elementCount;
+
+		KafkaSourceFunction(int elementCount) {
+			this.elementCount = elementCount;
+		}
+
+		@Override
+		public void run(SourceContext<Tuple3<Integer, Long, String>> ctx) {
+			long timestamp = 1584349939799L;
+			int instanceId = getRuntimeContext().getIndexOfThisSubtask();
+			for (int i = 0; i < elementCount && running; i++) {
+				ctx.collect(new Tuple3<>(i, timestamp++, "source-instance-" + instanceId));
+			}
+		}
+
+		@Override
+		public void cancel() {
+			running = false;
+		}
+	}
+
+	private static class ElementCountNoMoreThanValidator
+		implements MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
+		private final int totalCount;
+		private int counter = 0;
+
+		ElementCountNoMoreThanValidator(int totalCount) {
+			this.totalCount = totalCount;
+		}
+
+		@Override
+		public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> element) throws Exception {
+			counter++;
+
+			if (counter > totalCount) {
+				throw new Exception("Error: number of elements more than expected");
+			}
+
+			return element;
+		}
+	}
+
+	private static class ElementCountNoLessThanValidator
+		implements MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
+		private final int totalCount;
+		private int counter = 0;
+
+		ElementCountNoLessThanValidator(int totalCount) {
+			this.totalCount = totalCount;
+		}
+
+		@Override
+		public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> element) throws Exception {
+			counter++;
+
+			if (counter == totalCount) {
+				throw new SuccessException();
+			}
+
+			return element;
+		}
+	}
+
+	private static class PartitionValidator
+		extends KeyedProcessFunction<Tuple, Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
+
+		private final KeySelector<Tuple3<Integer, Long, String>, Tuple> keySelector;
+		private final int numberOfPartitions;
+		private final String topic;
+
+		private int previousPartition;
+
+		PartitionValidator(
+			KeySelector<Tuple3<Integer, Long, String>, Tuple> keySelector, int numberOfPartitions, String topic) {
+			this.keySelector = keySelector;
+			this.numberOfPartitions = numberOfPartitions;
+			this.topic = topic;
+			this.previousPartition = -1;
+		}
+
+		@Override
+		public void processElement(
+			Tuple3<Integer, Long, String> in, Context ctx, Collector<Tuple3<Integer, Long, String>> out)

Review comment:
       nit: chop args
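
For clarity, this is roughly what I mean by chopping the args, one parameter per line (the `throws Exception` clause and the body are my assumption here, since the hunk is cut off above):

```java
@Override
public void processElement(
		Tuple3<Integer, Long, String> in,
		Context ctx,
		Collector<Tuple3<Integer, Long, String>> out) throws Exception {
	// validation logic unchanged
}
```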

##########
File path: flink-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/KafkaShuffleFetcher.java
##########
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.internal;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
+import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaCommitCallback;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Properties;
+
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITHOUT_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_REC_WITH_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kafka.shuffle.FlinkKafkaShuffleProducer.KafkaSerializer.TAG_WATERMARK;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Fetch data from Kafka for Kafka Shuffle.
+ */
+@Internal
+public class KafkaShuffleFetcher<T> extends AbstractFetcher<T, TopicPartition> {
+	private static final Logger LOG = LoggerFactory.getLogger(KafkaShuffleFetcher.class);
+
+	private final WatermarkHandler watermarkHandler;
+	// ------------------------------------------------------------------------
+
+	/** The schema to convert between Kafka's byte messages, and Flink's objects. */
+	private final KafkaShuffleElementDeserializer<T> deserializer;
+
+	/** Serializer to serialize records. */
+	private final TypeSerializer<T> serializer;
+
+	/** The handover of data and exceptions between the consumer thread and the task thread. */
+	private final Handover handover;
+
+	/** The thread that runs the actual KafkaConsumer and hands the record batches to this fetcher. */
+	private final KafkaConsumerThread consumerThread;
+
+	/** Flag to mark the main work loop as alive. */
+	private volatile boolean running = true;
+
+	public KafkaShuffleFetcher(
+		SourceFunction.SourceContext<T> sourceContext,
+		Map<KafkaTopicPartition, Long> assignedPartitionsWithInitialOffsets,
+		SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+		SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+		ProcessingTimeService processingTimeProvider,
+		long autoWatermarkInterval,
+		ClassLoader userCodeClassLoader,
+		String taskNameWithSubtasks,
+		TypeSerializer<T> serializer,
+		Properties kafkaProperties,
+		long pollTimeout,
+		MetricGroup subtaskMetricGroup,
+		MetricGroup consumerMetricGroup,
+		boolean useMetrics,
+		int producerParallelism) throws Exception {
+		super(
+			sourceContext,
+			assignedPartitionsWithInitialOffsets,
+			watermarksPeriodic,
+			watermarksPunctuated,
+			processingTimeProvider,
+			autoWatermarkInterval,
+			userCodeClassLoader,
+			consumerMetricGroup,
+			useMetrics);
+
+		this.deserializer = new KafkaShuffleElementDeserializer<>();
+		this.serializer = serializer;
+		this.handover = new Handover();
+		this.consumerThread = new KafkaConsumerThread(
+			LOG,
+			handover,
+			kafkaProperties,
+			unassignedPartitionsQueue,
+			getFetcherName() + " for " + taskNameWithSubtasks,
+			pollTimeout,
+			useMetrics,
+			consumerMetricGroup,
+			subtaskMetricGroup);
+		this.watermarkHandler = new WatermarkHandler(producerParallelism);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Fetcher work methods
+	// ------------------------------------------------------------------------
+
+	@Override
+	public void runFetchLoop() throws Exception {
+		try {
+			final Handover handover = this.handover;
+
+			// kick off the actual Kafka consumer
+			consumerThread.start();
+
+			while (running) {
+				// this blocks until we get the next records
+				// it automatically re-throws exceptions encountered in the consumer thread
+				final ConsumerRecords<byte[], byte[]> records = handover.pollNext();
+
+				// get the records for each topic partition
+				for (KafkaTopicPartitionState<TopicPartition> partition : subscribedPartitionStates()) {
+					List<ConsumerRecord<byte[], byte[]>> partitionRecords =
+						records.records(partition.getKafkaPartitionHandle());
+
+					for (ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+						final KafkaShuffleElement<T> element = deserializer.deserialize(serializer, record);
+
+						// TODO: do we need to check for the end of stream when reaching the end watermark?
+
+						if (element.isRecord()) {
+							// The timestamp is inherited from upstream.
+							// If using ProcessingTime, the timestamp is ignored (upstream does not include a timestamp either).
+							// If using IngestionTime, the timestamp is overwritten.
+							// If using EventTime, the timestamp is used as-is.
+							synchronized (checkpointLock) {
+								KafkaShuffleRecord<T> elementAsRecord = element.asRecord();
+								sourceContext.collectWithTimestamp(
+									elementAsRecord.value,
+									elementAsRecord.timestamp == null ? record.timestamp() : elementAsRecord.timestamp);
+								partition.setOffset(record.offset());
+							}
+						} else if (element.isWatermark()) {
+							final KafkaShuffleWatermark watermark = element.asWatermark();
+							Optional<Watermark> newWatermark = watermarkHandler.checkAndGetNewWatermark(watermark);
+							newWatermark.ifPresent(sourceContext::emitWatermark);
+						}
+					}
+				}
+			}
+		}
+		finally {
+			// this signals the consumer thread that no more work is to be done
+			consumerThread.shutdown();
+		}
+
+		// on a clean exit, wait for the runner thread
+		try {
+			consumerThread.join();
+		}
+		catch (InterruptedException e) {
+			// may be the result of a wake-up interruption after an exception.
+			// we ignore this here and only restore the interruption state
+			Thread.currentThread().interrupt();
+		}
+	}
+
+	@Override
+	public void cancel() {
+		// flag the main thread to exit. A thread interrupt will come anyways.
+		running = false;
+		handover.close();
+		consumerThread.shutdown();
+	}
+
+	@Override
+	protected TopicPartition createKafkaPartitionHandle(KafkaTopicPartition partition) {
+		return new TopicPartition(partition.getTopic(), partition.getPartition());
+	}
+
+	@Override
+	protected void doCommitInternalOffsetsToKafka(
+		Map<KafkaTopicPartition, Long> offsets,
+		@Nonnull KafkaCommitCallback commitCallback) throws Exception {
+
+		@SuppressWarnings("unchecked")
+		List<KafkaTopicPartitionState<TopicPartition>> partitions = subscribedPartitionStates();
+
+		Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>(partitions.size());
+
+		for (KafkaTopicPartitionState<TopicPartition> partition : partitions) {
+			Long lastProcessedOffset = offsets.get(partition.getKafkaTopicPartition());
+			if (lastProcessedOffset != null) {
+				checkState(lastProcessedOffset >= 0, "Illegal offset value to commit");
+
+				// committed offsets through the KafkaConsumer need to be 1 more than the last processed offset.
+				// This does not affect Flink's checkpoints/saved state.
+				long offsetToCommit = lastProcessedOffset + 1;
+
+				offsetsToCommit.put(partition.getKafkaPartitionHandle(), new OffsetAndMetadata(offsetToCommit));
+				partition.setCommittedOffset(offsetToCommit);
+			}
+		}
+
+		// record the work to be committed by the main consumer thread and make sure the consumer notices that
+		consumerThread.setOffsetsToCommit(offsetsToCommit, commitCallback);
+	}
+
+	private String getFetcherName() {
+		return "Kafka Shuffle Fetcher";
+	}
+
+	private abstract static class KafkaShuffleElement<T> {
+
+		public boolean isRecord() {
+			return getClass() == KafkaShuffleRecord.class;
+		}
+
+		boolean isWatermark() {
+			return getClass() == KafkaShuffleWatermark.class;
+		}
+
+		KafkaShuffleRecord<T> asRecord() {
+			return (KafkaShuffleRecord<T>) this;
+		}
+
+		KafkaShuffleWatermark asWatermark() {
+			return (KafkaShuffleWatermark) this;
+		}
+	}
+
+	private static class KafkaShuffleWatermark<T> extends KafkaShuffleElement<T> {
+		final int subtask;
+		final long watermark;
+
+		KafkaShuffleWatermark(int subtask, long watermark) {
+			this.subtask = subtask;
+			this.watermark = watermark;
+		}
+	}
+
+	private static class KafkaShuffleRecord<T> extends KafkaShuffleElement<T> {
+		final T value;
+		final Long timestamp;
+
+		KafkaShuffleRecord(T value) {
+			this.value = value;
+			this.timestamp = null;
+		}
+
+		KafkaShuffleRecord(long timestamp, T value) {
+			this.value = value;
+			this.timestamp = timestamp;
+		}
+	}
+
+	private static class KafkaShuffleElementDeserializer<T> implements Serializable {
+		private transient DataInputDeserializer dis;
+
+		KafkaShuffleElementDeserializer() {
+			this.dis = new DataInputDeserializer();
+		}
+
+		KafkaShuffleElement<T> deserialize(TypeSerializer<T> serializer, ConsumerRecord<byte[], byte[]> record)
+			throws Exception {
+			byte[] value = record.value();
+			dis.setBuffer(value);
+			int tag = IntSerializer.INSTANCE.deserialize(dis);
+
+			if (tag == TAG_REC_WITHOUT_TIMESTAMP) {
+				return new KafkaShuffleRecord<>(serializer.deserialize(dis));
+			} else if (tag == TAG_REC_WITH_TIMESTAMP) {
+				return new KafkaShuffleRecord<>(LongSerializer.INSTANCE.deserialize(dis), serializer.deserialize(dis));
+			} else if (tag == TAG_WATERMARK) {
+				return new KafkaShuffleWatermark<>(
+					IntSerializer.INSTANCE.deserialize(dis), LongSerializer.INSTANCE.deserialize(dis));
+			}
+
+			throw new UnsupportedOperationException("Unsupported tag format");
+		}
+	}
+
+	/**
+	 * Combines the per-subtask watermarks reported by all producer subtasks and
+	 * emits their minimum as the consumer-side watermark.
+	 */
+	private static class WatermarkHandler {
+		private final int producerParallelism;
+		private final Map<Integer, Long> subtaskWatermark;
+
+		private long currentMinWatermark = Long.MIN_VALUE;
+
+		WatermarkHandler(int numberOfSubtask) {
+			this.producerParallelism = numberOfSubtask;
+			this.subtaskWatermark = new HashMap<>(numberOfSubtask);
+		}
+
+		public Optional<Watermark> checkAndGetNewWatermark(KafkaShuffleWatermark newWatermark) {
+			// watermarks are monotonically increasing for the same partition and PRODUCER subtask
+			Long currentSubTaskWatermark = subtaskWatermark.get(newWatermark.subtask);
+
+			Preconditions.checkState(
+				(currentSubTaskWatermark == null) || (currentSubTaskWatermark <= newWatermark.watermark),
+				"Watermark should always increase");
+
+			subtaskWatermark.put(newWatermark.subtask, newWatermark.watermark);
+
+			if (subtaskWatermark.values().size() < producerParallelism) {
+				return Optional.empty();
+			}

Review comment:
       What happens if one partition has ended and we receive no watermarks anymore? Are the watermarks of the other partitions still propagated properly? Almost feels like using `StatusWatermarkValve` would be handy.
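
To make the concern concrete, here is a purely illustrative sketch of how the handler could exclude producer subtasks that have gone idle from the min computation. The `markIdle` hook and the idea of an idleness signal are made up for the sake of discussion, not something this PR provides, and `StatusWatermarkValve` already does similar bookkeeping on the runtime side (imports as in the surrounding file, plus `java.util.Set`/`HashSet`):

```java
// Illustrative only: a WatermarkHandler variant that tolerates idle producer subtasks.
// Assumes something upstream (e.g. a processing-time timer) calls markIdle(subtask)
// when no watermark has arrived from that subtask for a configured timeout.
private static class IdleAwareWatermarkHandler {
	private final int producerParallelism;
	private final Map<Integer, Long> subtaskWatermark = new HashMap<>();
	private final Set<Integer> idleSubtasks = new HashSet<>();
	private long currentMinWatermark = Long.MIN_VALUE;

	IdleAwareWatermarkHandler(int producerParallelism) {
		this.producerParallelism = producerParallelism;
	}

	void markIdle(int subtask) {
		idleSubtasks.add(subtask);
	}

	Optional<Watermark> checkAndGetNewWatermark(int subtask, long watermark) {
		// any incoming watermark re-activates the subtask
		idleSubtasks.remove(subtask);
		subtaskWatermark.put(subtask, watermark);

		// wait until every producer subtask has either reported once or been marked idle
		Set<Integer> seenOrIdle = new HashSet<>(subtaskWatermark.keySet());
		seenOrIdle.addAll(idleSubtasks);
		if (seenOrIdle.size() < producerParallelism) {
			return Optional.empty();
		}

		// take the minimum over the non-idle subtasks only
		long min = subtaskWatermark.entrySet().stream()
			.filter(e -> !idleSubtasks.contains(e.getKey()))
			.mapToLong(Map.Entry::getValue)
			.min()
			.orElse(currentMinWatermark);

		if (min > currentMinWatermark) {
			currentMinWatermark = min;
			return Optional.of(new Watermark(min));
		}
		return Optional.empty();
	}
}
```

Whether an idleness timeout, a dedicated end-of-stream marker in the topic, or reusing `StatusWatermarkValve` is the right mechanism is exactly the open question above.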




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org