Posted to commits@flink.apache.org by tz...@apache.org on 2020/08/12 06:39:18 UTC

[flink] branch master updated: [FLINK-18483][kinesis] Test coverage improvements for FlinkKinesisConsumer/ShardConsumer

This is an automated email from the ASF dual-hosted git repository.

tzulitai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 9064e1c  [FLINK-18483][kinesis] Test coverage improvements for FlinkKinesisConsumer/ShardConsumer
9064e1c is described below

commit 9064e1cd1cff0b7adc7a1e3c49606fe14b07e24f
Author: Danny Cranmer <cr...@amazon.com>
AuthorDate: Mon Jun 29 14:24:12 2020 +0100

    [FLINK-18483][kinesis] Test coverage improvements for FlinkKinesisConsumer/ShardConsumer
    
    This closes #12850.
---
 flink-connectors/flink-connector-kinesis/pom.xml   |   7 +
 .../kinesis/internals/ShardConsumerTest.java       | 228 ++++++++++-----------
 .../testutils/FakeKinesisBehavioursFactory.java    | 177 ++++++++++------
 .../connectors/kinesis/testutils/TestUtils.java    |  57 ++++++
 4 files changed, 287 insertions(+), 182 deletions(-)

diff --git a/flink-connectors/flink-connector-kinesis/pom.xml b/flink-connectors/flink-connector-kinesis/pom.xml
index baa067d..8e0bf7c 100644
--- a/flink-connectors/flink-connector-kinesis/pom.xml
+++ b/flink-connectors/flink-connector-kinesis/pom.xml
@@ -97,6 +97,13 @@ under the License.
 
 		<dependency>
 			<groupId>com.amazonaws</groupId>
+			<artifactId>amazon-kinesis-aggregator</artifactId>
+			<version>1.0.3</version>
+			<scope>test</scope>
+		</dependency>
+
+		<dependency>
+			<groupId>com.amazonaws</groupId>
 			<artifactId>aws-java-sdk-kinesis</artifactId>
 			<version>${aws.sdk.version}</version>
 		</dependency>
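
Note on the new test dependency: amazon-kinesis-aggregator supplies the RecordAggregator/AggRecord API that the new TestUtils helper (further down in this diff) uses to build KPL-style aggregated records. As a rough sketch of the shape of that API, using only the calls exercised in this commit (addUserRecord, clearAndGet, toRecordBytes); the wrapper class and error handling here are illustrative, not part of the change:

    // Minimal sketch, assuming the aggregator API as used in TestUtils below.
    import com.amazonaws.kinesis.agg.AggRecord;
    import com.amazonaws.kinesis.agg.RecordAggregator;

    import java.nio.charset.StandardCharsets;

    public class AggregationSketch {
        public static byte[] aggregate(String partitionKey, String... payloads) throws Exception {
            RecordAggregator aggregator = new RecordAggregator();
            for (String payload : payloads) {
                // Each call appends one child (user) record to the pending aggregate.
                aggregator.addUserRecord(partitionKey, payload.getBytes(StandardCharsets.UTF_8));
            }
            // Seal the aggregate and reset the aggregator for the next batch.
            AggRecord aggRecord = aggregator.clearAndGet();
            // The KPL wire format that a consumer-side deaggregator understands.
            return aggRecord.toRecordBytes();
        }
    }
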
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
index f2552ee..6fbe4ee 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
@@ -20,7 +20,6 @@ package org.apache.flink.streaming.connectors.kinesis.internals;
 import org.apache.flink.api.common.serialization.SimpleStringSchema;
 import org.apache.flink.streaming.connectors.kinesis.metrics.ShardMetricsReporter;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
-import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle;
 import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface;
@@ -37,12 +36,24 @@ import org.junit.Test;
 import org.mockito.Mockito;
 
 import java.math.BigInteger;
+import java.text.SimpleDateFormat;
 import java.util.Collections;
+import java.util.Date;
 import java.util.LinkedList;
 import java.util.Properties;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.SHARD_USE_ADAPTIVE_READS;
+import static org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.STREAM_INITIAL_TIMESTAMP;
+import static org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.STREAM_TIMESTAMP_DATE_FORMAT;
+import static org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber.SENTINEL_AT_TIMESTAMP_SEQUENCE_NUM;
+import static org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM;
+import static org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM;
 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 
 /**
  * Tests for the {@link ShardConsumer}.
@@ -51,148 +62,123 @@ public class ShardConsumerTest {
 
 	@Test
 	public void testMetricsReporting() {
-		StreamShardHandle fakeToBeConsumedShard = getMockStreamShard("fakeStream", 0);
-
-		LinkedList<KinesisStreamShardState> subscribedShardsStateUnderTest = new LinkedList<>();
-		subscribedShardsStateUnderTest.add(
-			new KinesisStreamShardState(
-				KinesisDataFetcher.convertToStreamShardMetadata(fakeToBeConsumedShard),
-				fakeToBeConsumedShard,
-				new SequenceNumber("fakeStartingState")));
+		KinesisProxyInterface kinesis = FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(500, 5, 500);
 
-		TestSourceContext<String> sourceContext = new TestSourceContext<>();
-
-		KinesisDeserializationSchemaWrapper<String> deserializationSchema = new KinesisDeserializationSchemaWrapper<>(
-			new SimpleStringSchema());
-		TestableKinesisDataFetcher<String> fetcher =
-			new TestableKinesisDataFetcher<>(
-				Collections.singletonList("fakeStream"),
-				sourceContext,
-				new Properties(),
-				deserializationSchema,
-				10,
-				2,
-				new AtomicReference<>(),
-				subscribedShardsStateUnderTest,
-				KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")),
-				Mockito.mock(KinesisProxyInterface.class));
+		ShardMetricsReporter metrics = assertNumberOfMessagesReceivedFromKinesis(500, kinesis, fakeSequenceNumber());
+		assertEquals(500, metrics.getMillisBehindLatest());
+		assertEquals(10000, metrics.getMaxNumberOfRecordsPerFetch());
+	}
 
-		ShardMetricsReporter shardMetricsReporter = new ShardMetricsReporter();
-		long millisBehindLatest = 500L;
-		new ShardConsumer<>(
-			fetcher,
-			0,
-			subscribedShardsStateUnderTest.get(0).getStreamShardHandle(),
-			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(),
-			FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(1000, 9, millisBehindLatest),
-			shardMetricsReporter,
-			deserializationSchema)
-			.run();
+	@Test
+	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithStartingSequenceNumber() throws Exception {
+		KinesisProxyInterface kinesis = spy(FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(1000, 9, 500L));
 
-		// the millisBehindLatest metric should have been reported
-		assertEquals(millisBehindLatest, shardMetricsReporter.getMillisBehindLatest());
+		assertNumberOfMessagesReceivedFromKinesis(1000, kinesis, fakeSequenceNumber());
+		verify(kinesis).getShardIterator(any(), eq("AFTER_SEQUENCE_NUMBER"), eq("fakeStartingState"));
 	}
 
 	@Test
-	public void testCorrectNumOfCollectedRecordsAndUpdatedState() {
-		StreamShardHandle fakeToBeConsumedShard = getMockStreamShard("fakeStream", 0);
+	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithStartingSequenceSentinelTimestamp() throws Exception {
+		String format = "yyyy-MM-dd'T'HH:mm";
+		String timestamp = "2020-07-02T09:14";
+		Date expectedTimestamp = new SimpleDateFormat(format).parse(timestamp);
 
-		LinkedList<KinesisStreamShardState> subscribedShardsStateUnderTest = new LinkedList<>();
-		subscribedShardsStateUnderTest.add(
-			new KinesisStreamShardState(KinesisDataFetcher.convertToStreamShardMetadata(fakeToBeConsumedShard),
-				fakeToBeConsumedShard, new SequenceNumber("fakeStartingState")));
+		Properties consumerProperties = new Properties();
+		consumerProperties.setProperty(STREAM_INITIAL_TIMESTAMP, timestamp);
+		consumerProperties.setProperty(STREAM_TIMESTAMP_DATE_FORMAT, format);
+		SequenceNumber sequenceNumber = SENTINEL_AT_TIMESTAMP_SEQUENCE_NUM.get();
 
-		TestSourceContext<String> sourceContext = new TestSourceContext<>();
+		KinesisProxyInterface kinesis = spy(FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(10, 1, 0));
 
-		KinesisDeserializationSchemaWrapper<String> deserializationSchema = new KinesisDeserializationSchemaWrapper<>(
-			new SimpleStringSchema());
-		TestableKinesisDataFetcher<String> fetcher =
-			new TestableKinesisDataFetcher<>(
-				Collections.singletonList("fakeStream"),
-				sourceContext,
-				new Properties(),
-				deserializationSchema,
-				10,
-				2,
-				new AtomicReference<>(),
-				subscribedShardsStateUnderTest,
-				KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")),
-				Mockito.mock(KinesisProxyInterface.class));
+		assertNumberOfMessagesReceivedFromKinesis(10, kinesis, sequenceNumber, consumerProperties);
+		verify(kinesis).getShardIterator(any(), eq("AT_TIMESTAMP"), eq(expectedTimestamp));
+	}
 
-		int shardIndex = fetcher.registerNewSubscribedShardState(subscribedShardsStateUnderTest.get(0));
-		new ShardConsumer<>(
-			fetcher,
-			shardIndex,
-			subscribedShardsStateUnderTest.get(0).getStreamShardHandle(),
-			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(),
-			FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(1000, 9, 500L),
-			new ShardMetricsReporter(),
-			deserializationSchema)
-			.run();
+	@Test
+	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithStartingSequenceSentinelEarliest() throws Exception {
+		SequenceNumber sequenceNumber = SENTINEL_EARLIEST_SEQUENCE_NUM.get();
 
-		assertEquals(1000, sourceContext.getCollectedOutputs().size());
-		assertEquals(
-			SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get(),
-			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum());
+		KinesisProxyInterface kinesis = spy(FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(50, 2, 0));
+
+		assertNumberOfMessagesReceivedFromKinesis(50, kinesis, sequenceNumber);
+		verify(kinesis).getShardIterator(any(), eq("TRIM_HORIZON"), eq(null));
 	}
 
 	@Test
 	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithUnexpectedExpiredIterator() {
-		StreamShardHandle fakeToBeConsumedShard = getMockStreamShard("fakeStream", 0);
+		KinesisProxyInterface kinesis = FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCallsWithUnexpectedExpiredIterator(1000, 9, 7, 500L);
 
-		LinkedList<KinesisStreamShardState> subscribedShardsStateUnderTest = new LinkedList<>();
-		subscribedShardsStateUnderTest.add(
-			new KinesisStreamShardState(KinesisDataFetcher.convertToStreamShardMetadata(fakeToBeConsumedShard),
-				fakeToBeConsumedShard, new SequenceNumber("fakeStartingState")));
+		// Get a total of 1000 records with 9 getRecords() calls,
+		// and the 7th getRecords() call will encounter an unexpected expired shard iterator
+		assertNumberOfMessagesReceivedFromKinesis(1000, kinesis, fakeSequenceNumber());
+	}
 
-		TestSourceContext<String> sourceContext = new TestSourceContext<>();
+	@Test
+	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithAdaptiveReads() {
+		Properties consumerProperties = new Properties();
+		consumerProperties.setProperty(SHARD_USE_ADAPTIVE_READS, "true");
 
-		KinesisDeserializationSchemaWrapper<String> deserializationSchema = new KinesisDeserializationSchemaWrapper<>(
-			new SimpleStringSchema());
-		TestableKinesisDataFetcher<String> fetcher =
-			new TestableKinesisDataFetcher<>(
-				Collections.singletonList("fakeStream"),
-				sourceContext,
-				new Properties(),
-				deserializationSchema,
-				10,
-				2,
-				new AtomicReference<>(),
-				subscribedShardsStateUnderTest,
-				KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")),
-				Mockito.mock(KinesisProxyInterface.class));
+		KinesisProxyInterface kinesis = FakeKinesisBehavioursFactory.initialNumOfRecordsAfterNumOfGetRecordsCallsWithAdaptiveReads(10, 2, 500L);
 
-		int shardIndex = fetcher.registerNewSubscribedShardState(subscribedShardsStateUnderTest.get(0));
-		new ShardConsumer<>(
-			fetcher,
-			shardIndex,
-			subscribedShardsStateUnderTest.get(0).getStreamShardHandle(),
-			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(),
-			// Get a total of 1000 records with 9 getRecords() calls,
-			// and the 7th getRecords() call will encounter an unexpected expired shard iterator
-			FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCallsWithUnexpectedExpiredIterator(
-				1000, 9, 7, 500L),
-			new ShardMetricsReporter(),
-			deserializationSchema)
-			.run();
+		// Avg record size for first batch --> (10 records * 10 Kb) / 10 = 10 Kb
+		// Number of records fetched in second batch --> 2 Mb / (10 Kb * 5) = 40
+		// Total number of records = 10 + 40 = 50
+		assertNumberOfMessagesReceivedFromKinesis(50, kinesis, fakeSequenceNumber(), consumerProperties);
+	}
 
-		assertEquals(1000, sourceContext.getCollectedOutputs().size());
-		assertEquals(
-			SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get(),
-			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum());
+	@Test
+	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithAggregatedRecords() throws Exception {
+		KinesisProxyInterface kinesis = spy(FakeKinesisBehavioursFactory.aggregatedRecords(3, 5, 10));
+
+		// Expecting to receive all messages
+		// 10 batches of 3 aggregated records each with 5 child records
+		// 10 * 3 * 5 = 150
+		ShardMetricsReporter metrics = assertNumberOfMessagesReceivedFromKinesis(150, kinesis, fakeSequenceNumber());
+		assertEquals(3, metrics.getNumberOfAggregatedRecords());
+		assertEquals(15, metrics.getNumberOfDeaggregatedRecords());
+
+		verify(kinesis).getShardIterator(any(), eq("AFTER_SEQUENCE_NUMBER"), eq("fakeStartingState"));
 	}
 
 	@Test
-	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithAdaptiveReads() {
-		Properties consumerProperties = new Properties();
-		consumerProperties.put("flink.shard.adaptivereads", "true");
+	public void testCorrectNumOfCollectedRecordsAndUpdatedStateWithAggregatedRecordsWithSubSequenceStartingNumber() throws Exception {
+		SequenceNumber sequenceNumber = new SequenceNumber("0", 5);
+		KinesisProxyInterface kinesis = spy(FakeKinesisBehavioursFactory.aggregatedRecords(1, 10, 5));
+
+		// Expecting to start consuming from last sub sequence number
+		// 5 batches of 1 aggregated record each with 10 child records
+		// Last consumed message was sub-sequence 5, i.e. the 6th of 10 (zero-based); remaining are 6, 7, 8, 9
+		// 5 * 1 * 10 - 6 = 44
+		ShardMetricsReporter metrics = assertNumberOfMessagesReceivedFromKinesis(44, kinesis, sequenceNumber);
+		assertEquals(1, metrics.getNumberOfAggregatedRecords());
+		assertEquals(10, metrics.getNumberOfDeaggregatedRecords());
+
+		verify(kinesis).getShardIterator(any(), eq("AT_SEQUENCE_NUMBER"), eq("0"));
+	}
+
+	private SequenceNumber fakeSequenceNumber() {
+		return new SequenceNumber("fakeStartingState");
+	}
+
+	private ShardMetricsReporter assertNumberOfMessagesReceivedFromKinesis(
+			final int expectedNumberOfMessages,
+			final KinesisProxyInterface kinesis,
+			final SequenceNumber startingSequenceNumber) {
+		return assertNumberOfMessagesReceivedFromKinesis(expectedNumberOfMessages, kinesis, startingSequenceNumber, new Properties());
+	}
 
+	private ShardMetricsReporter assertNumberOfMessagesReceivedFromKinesis(
+			final int expectedNumberOfMessages,
+			final KinesisProxyInterface kinesis,
+			final SequenceNumber startingSequenceNumber,
+			final Properties consumerProperties) {
+		ShardMetricsReporter shardMetricsReporter = new ShardMetricsReporter();
 		StreamShardHandle fakeToBeConsumedShard = getMockStreamShard("fakeStream", 0);
 
 		LinkedList<KinesisStreamShardState> subscribedShardsStateUnderTest = new LinkedList<>();
 		subscribedShardsStateUnderTest.add(
 			new KinesisStreamShardState(KinesisDataFetcher.convertToStreamShardMetadata(fakeToBeConsumedShard),
-				fakeToBeConsumedShard, new SequenceNumber("fakeStartingState")));
+				fakeToBeConsumedShard, startingSequenceNumber));
 
 		TestSourceContext<String> sourceContext = new TestSourceContext<>();
 
@@ -217,19 +203,17 @@ public class ShardConsumerTest {
 			shardIndex,
 			subscribedShardsStateUnderTest.get(0).getStreamShardHandle(),
 			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(),
-			// Initial number of records to fetch --> 10
-			FakeKinesisBehavioursFactory.initialNumOfRecordsAfterNumOfGetRecordsCallsWithAdaptiveReads(10, 2, 500L),
-			new ShardMetricsReporter(),
+			kinesis,
+			shardMetricsReporter,
 			deserializationSchema)
 			.run();
 
-		// Avg record size for first batch --> 10 * 10 Kb/10 = 10 Kb
-		// Number of records fetched in second batch --> 2 Mb/10Kb * 5 = 40
-		// Total number of records = 10 + 40 = 50
-		assertEquals(50, sourceContext.getCollectedOutputs().size());
+		assertEquals(expectedNumberOfMessages, sourceContext.getCollectedOutputs().size());
 		assertEquals(
-			SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get(),
+			SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get(),
 			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum());
+
+		return shardMetricsReporter;
 	}
 
 	private static StreamShardHandle getMockStreamShard(String streamName, int shardId) {
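
Note on the adaptive-reads arithmetic in testCorrectNumOfCollectedRecordsAndUpdatedStateWithAdaptiveReads: the "50 records" expectation follows from the Kinesis per-shard read limit and the average record size of the first batch, mirroring the formula in FakeKinesisBehavioursFactory below. A worked version of the comment's numbers (the 200 ms value is Flink's DEFAULT_SHARD_GETRECORDS_INTERVAL_MILLIS, assumed here):

    // Sketch of the arithmetic only; not part of the test code.
    long shardBytesPerSecondLimit = 2 * 1024L * 1024L; // 2 MB/s per-shard read limit
    long averageRecordSizeBytes = 10 * 1024L;          // first batch: 10 records of 10 Kb each
    long fetchIntervalMillis = 200L;                   // assumed default fetch interval
    long secondBatchRecords = shardBytesPerSecondLimit
            / (averageRecordSizeBytes * 1000L / fetchIntervalMillis); // 2097152 / 51200 = 40
    // Total records = 10 (first batch) + 40 (second batch) = 50.
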
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java
index ee4e0a3..177a7cf 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/FakeKinesisBehavioursFactory.java
@@ -41,6 +41,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 
@@ -107,6 +108,25 @@ public class FakeKinesisBehavioursFactory {
 				millisBehindLatest);
 	}
 
+	/**
+	 * Creates a mocked Kinesis Proxy that will emit aggregated records from a fake stream:
+	 * - There will be {@code numOfGetRecordsCalls} batches available in the stream
+	 * - Each batch will contain {@code numOfAggregatedRecords} aggregated records
+	 * - Each aggregated record will contain {@code numOfChildRecords} child records
+	 * Therefore the returned proxy will emit a total of
+	 * {@code numOfGetRecordsCalls * numOfAggregatedRecords * numOfChildRecords} records.
+	 *
+	 * @param numOfAggregatedRecords the number of aggregated records per batch
+	 * @param numOfChildRecords the number of child records in each aggregated record
+	 * @param numOfGetRecordsCalls the number of batches available in the fake stream
+	 */
+	public static KinesisProxyInterface aggregatedRecords(
+			final int numOfAggregatedRecords,
+			final int numOfChildRecords,
+			final int numOfGetRecordsCalls) {
+		return new SingleShardEmittingAggregatedRecordsKinesis(numOfAggregatedRecords, numOfChildRecords, numOfGetRecordsCalls);
+	}
+
 	public static KinesisProxyInterface blockingQueueGetRecords(Map<String, List<BlockingQueue<String>>> streamsToShardQueues) {
 		return new BlockingQueueKinesis(streamsToShardQueues);
 	}
@@ -133,7 +153,7 @@ public class FakeKinesisBehavioursFactory {
 
 		@Override
 		public GetRecordsResult getRecords(String shardIterator, int maxRecordsToGet) {
-			if ((Integer.valueOf(shardIterator) == orderOfCallToExpire - 1) && !expiredOnceAlready) {
+			if ((Integer.parseInt(shardIterator) == orderOfCallToExpire - 1) && !expiredOnceAlready) {
 				// we fake only once the expired iterator exception at the specified get records attempt order
 				expiredOnceAlready = true;
 				throw new ExpiredIteratorException("Artificial expired shard iterator");
@@ -147,8 +167,8 @@ public class FakeKinesisBehavioursFactory {
 					.withRecords(shardItrToRecordBatch.get(shardIterator))
 					.withMillisBehindLatest(millisBehindLatest)
 					.withNextShardIterator(
-						(Integer.valueOf(shardIterator) == totalNumOfGetRecordsCalls - 1)
-							? null : String.valueOf(Integer.valueOf(shardIterator) + 1)); // last next shard iterator is null
+						(Integer.parseInt(shardIterator) == totalNumOfGetRecordsCalls - 1)
+							? null : String.valueOf(Integer.parseInt(shardIterator) + 1)); // last next shard iterator is null
 			}
 		}
 
@@ -214,8 +234,8 @@ public class FakeKinesisBehavioursFactory {
 				.withRecords(shardItrToRecordBatch.get(shardIterator))
 				.withMillisBehindLatest(millisBehindLatest)
 				.withNextShardIterator(
-					(Integer.valueOf(shardIterator) == totalNumOfGetRecordsCalls - 1)
-						? null : String.valueOf(Integer.valueOf(shardIterator) + 1)); // last next shard iterator is null
+					(Integer.parseInt(shardIterator) == totalNumOfGetRecordsCalls - 1)
+						? null : String.valueOf(Integer.parseInt(shardIterator) + 1)); // last next shard iterator is null
 		}
 
 		@Override
@@ -245,73 +265,43 @@ public class FakeKinesisBehavioursFactory {
 
 	}
 
-	private static class SingleShardEmittingAdaptiveNumOfRecordsKinesis implements
-			KinesisProxyInterface {
-
-		protected final int totalNumOfGetRecordsCalls;
-
-		protected final int totalNumOfRecords;
-
-		private final long millisBehindLatest;
-
-		protected final Map<String, List<Record>> shardItrToRecordBatch;
+	private static class SingleShardEmittingAdaptiveNumOfRecordsKinesis extends SingleShardEmittingKinesis {
 
-		protected static long averageRecordSizeBytes;
+		protected static long averageRecordSizeBytes = 0L;
 
 		private static final long KINESIS_SHARD_BYTES_PER_SECOND_LIMIT = 2 * 1024L * 1024L;
 
-		public SingleShardEmittingAdaptiveNumOfRecordsKinesis(final int numOfRecords,
+		public SingleShardEmittingAdaptiveNumOfRecordsKinesis(
+				final int numOfRecords,
 				final int numOfGetRecordsCalls,
 				final long millisBehindLatest) {
-			this.totalNumOfRecords = numOfRecords;
-			this.totalNumOfGetRecordsCalls = numOfGetRecordsCalls;
-			this.millisBehindLatest = millisBehindLatest;
-			this.averageRecordSizeBytes = 0L;
+			super(initShardItrToRecordBatch(numOfRecords, numOfGetRecordsCalls), millisBehindLatest);
+		}
 
+		private static Map<String, List<Record>> initShardItrToRecordBatch(
+				final int numOfRecords,
+				final int numOfGetRecordsCalls) {
 			// initialize the record batches that will be fetched
-			this.shardItrToRecordBatch = new HashMap<>();
+			Map<String, List<Record>> shardItrToRecordBatch = new HashMap<>();
 
 			int numOfAlreadyPartitionedRecords = 0;
 			int numOfRecordsPerBatch = numOfRecords;
-			for (int batch = 0; batch < totalNumOfGetRecordsCalls; batch++) {
-					shardItrToRecordBatch.put(
-							String.valueOf(batch),
-							createRecordBatchWithRange(
-									numOfAlreadyPartitionedRecords,
-									numOfAlreadyPartitionedRecords + numOfRecordsPerBatch));
-					numOfAlreadyPartitionedRecords += numOfRecordsPerBatch;
+			for (int batch = 0; batch < numOfGetRecordsCalls; batch++) {
+				shardItrToRecordBatch.put(
+					String.valueOf(batch),
+					createRecordBatchWithRange(
+						numOfAlreadyPartitionedRecords,
+						numOfAlreadyPartitionedRecords + numOfRecordsPerBatch));
+				numOfAlreadyPartitionedRecords += numOfRecordsPerBatch;
 
 				numOfRecordsPerBatch = (int) (KINESIS_SHARD_BYTES_PER_SECOND_LIMIT /
-						(averageRecordSizeBytes * 1000L / ConsumerConfigConstants.DEFAULT_SHARD_GETRECORDS_INTERVAL_MILLIS));
+					(averageRecordSizeBytes * 1000L / ConsumerConfigConstants.DEFAULT_SHARD_GETRECORDS_INTERVAL_MILLIS));
 			}
-		}
-
-		@Override
-		public GetRecordsResult getRecords(String shardIterator, int maxRecordsToGet) {
-			// assuming that the maxRecordsToGet is always large enough
-			return new GetRecordsResult()
-					.withRecords(shardItrToRecordBatch.get(shardIterator))
-					.withMillisBehindLatest(millisBehindLatest)
-					.withNextShardIterator(
-							(Integer.valueOf(shardIterator) == totalNumOfGetRecordsCalls - 1)
-									? null : String
-									.valueOf(Integer.valueOf(shardIterator) + 1)); // last next shard iterator is null
-		}
 
-		@Override
-		public String getShardIterator(StreamShardHandle shard, String shardIteratorType,
-				Object startingMarker) {
-			// this will be called only one time per ShardConsumer;
-			// so, simply return the iterator of the first batch of records
-			return "0";
-		}
-
-		@Override
-		public GetShardListResult getShardList(Map<String, String> streamNamesWithLastSeenShardIds) {
-			return null;
+			return shardItrToRecordBatch;
 		}
 
-		public static List<Record> createRecordBatchWithRange(int min, int max) {
+		private static List<Record> createRecordBatchWithRange(int min, int max) {
 			List<Record> batch = new LinkedList<>();
 			long sumRecordBatchBytes = 0L;
 			// Create record of size 10Kb
@@ -320,7 +310,7 @@ public class FakeKinesisBehavioursFactory {
 			for (int i = min; i < max; i++) {
 				Record record = new Record()
 								.withData(
-										ByteBuffer.wrap(String.valueOf(data).getBytes(ConfigConstants.DEFAULT_CHARSET)))
+										ByteBuffer.wrap(data.getBytes(ConfigConstants.DEFAULT_CHARSET)))
 								.withPartitionKey(UUID.randomUUID().toString())
 								.withApproximateArrivalTimestamp(new Date(System.currentTimeMillis()))
 								.withSequenceNumber(String.valueOf(i));
@@ -335,17 +325,84 @@ public class FakeKinesisBehavioursFactory {
 			return batch;
 		}
 
-		private static String createDataSize(long msgSize) {
+		private static String createDataSize(final long msgSize) {
 			char[] data = new char[(int) msgSize];
 			return new String(data);
+		}
+	}
+
+	private static class SingleShardEmittingAggregatedRecordsKinesis extends SingleShardEmittingKinesis {
+
+		public SingleShardEmittingAggregatedRecordsKinesis(
+				final int numOfAggregatedRecords,
+				final int numOfChildRecords,
+				final int numOfGetRecordsCalls) {
+			super(initShardItrToRecordBatch(numOfAggregatedRecords, numOfChildRecords, numOfGetRecordsCalls));
+		}
+
+		private static Map<String, List<Record>> initShardItrToRecordBatch(final int numOfAggregatedRecords,
+				final int numOfChildRecords,
+				final int numOfGetRecordsCalls) {
+
+			Map<String, List<Record>> shardToRecordBatch = new HashMap<>();
+
+			AtomicInteger sequenceNumber = new AtomicInteger();
+			for (int batch = 0; batch < numOfGetRecordsCalls; batch++) {
+				List<Record> recordBatch = TestUtils.createAggregatedRecordBatch(
+					numOfAggregatedRecords, numOfChildRecords, sequenceNumber);
+
+				shardToRecordBatch.put(String.valueOf(batch), recordBatch);
+			}
 
+			return shardToRecordBatch;
+		}
+	}
+
+	/** A helper base class used to emit records from a single-shard fake Kinesis stream. */
+	private abstract static class SingleShardEmittingKinesis implements KinesisProxyInterface {
+
+		private final long millisBehindLatest;
+
+		private final Map<String, List<Record>> shardItrToRecordBatch;
+
+		protected SingleShardEmittingKinesis(final Map<String, List<Record>> shardItrToRecordBatch) {
+			this(shardItrToRecordBatch, 0L);
+		}
+
+		protected SingleShardEmittingKinesis(final Map<String, List<Record>> shardItrToRecordBatch, final long millisBehindLatest) {
+			this.millisBehindLatest = millisBehindLatest;
+			this.shardItrToRecordBatch = shardItrToRecordBatch;
 		}
 
+		@Override
+		public GetRecordsResult getRecords(String shardIterator, int maxRecordsToGet) {
+			int index = Integer.parseInt(shardIterator);
+			// last next shard iterator is null
+			String nextShardIterator = (index == shardItrToRecordBatch.size() - 1) ? null : String.valueOf(index + 1);
+
+			// assuming that the maxRecordsToGet is always large enough
+			return new GetRecordsResult()
+				.withRecords(shardItrToRecordBatch.get(shardIterator))
+				.withNextShardIterator(nextShardIterator)
+				.withMillisBehindLatest(millisBehindLatest);
+		}
+
+		@Override
+		public String getShardIterator(StreamShardHandle shard, String shardIteratorType, Object startingMarker) {
+			// this will be called only one time per ShardConsumer;
+			// so, simply return the iterator of the first batch of records
+			return "0";
+		}
+
+		@Override
+		public GetShardListResult getShardList(Map<String, String> streamNamesWithLastSeenShardIds) {
+			return null;
+		}
 	}
 
 	private static class NonReshardedStreamsKinesis implements KinesisProxyInterface {
 
-		private Map<String, List<StreamShardHandle>> streamsWithListOfShards = new HashMap<>();
+		private final Map<String, List<StreamShardHandle>> streamsWithListOfShards = new HashMap<>();
 
 		public NonReshardedStreamsKinesis(Map<String, Integer> streamsToShardCount) {
 			for (Map.Entry<String, Integer> streamToShardCount : streamsToShardCount.entrySet()) {
@@ -436,8 +493,8 @@ public class FakeKinesisBehavioursFactory {
 
 	private static class BlockingQueueKinesis implements KinesisProxyInterface {
 
-		private Map<String, List<StreamShardHandle>> streamsWithListOfShards = new HashMap<>();
-		private Map<String, BlockingQueue<String>> shardIteratorToQueueMap = new HashMap<>();
+		private final Map<String, List<StreamShardHandle>> streamsWithListOfShards = new HashMap<>();
+		private final Map<String, BlockingQueue<String>> shardIteratorToQueueMap = new HashMap<>();
 
 		private static String getShardIterator(StreamShardHandle shardHandle) {
 			return shardHandle.getStreamName() + "-" + shardHandle.getShard().getShardId();
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestUtils.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestUtils.java
index f6d0a44..7fce762 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestUtils.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestUtils.java
@@ -17,9 +17,21 @@
 
 package org.apache.flink.streaming.connectors.kinesis.testutils;
 
+import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants;
 
+import com.amazonaws.kinesis.agg.AggRecord;
+import com.amazonaws.kinesis.agg.RecordAggregator;
+import com.amazonaws.services.kinesis.model.Record;
+import org.apache.commons.lang3.RandomStringUtils;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
 import java.util.Properties;
+import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
 
 /**
  * General test utils.
@@ -36,4 +48,49 @@ public class TestUtils {
 
 		return config;
 	}
+
+	/**
+	 * Creates a batch of {@code numOfAggregatedRecords} aggregated records.
+	 * Each aggregated record contains {@code numOfChildRecords} child records.
+	 * Each aggregated record is assigned the sequence number {@code sequenceNumber + index * numOfChildRecords}.
+	 * The next sequence number is written back to {@code sequenceNumber}.
+	 *
+	 * @param numOfAggregatedRecords the number of aggregated records in the batch
+	 * @param numOfChildRecords the number of child records for each aggregated record
+	 * @param sequenceNumber the starting sequence number; updated in place to the next sequence number
+	 * @return the batch of aggregated records
+	 */
+	public static List<Record> createAggregatedRecordBatch(
+			final int numOfAggregatedRecords,
+			final int numOfChildRecords,
+			final AtomicInteger sequenceNumber) {
+		List<Record> recordBatch = new ArrayList<>();
+		RecordAggregator recordAggregator = new RecordAggregator();
+
+		for (int record = 0; record < numOfAggregatedRecords; record++) {
+			String partitionKey = UUID.randomUUID().toString();
+
+			for (int child = 0; child < numOfChildRecords; child++) {
+				byte[] data = RandomStringUtils.randomAlphabetic(1024)
+					.getBytes(ConfigConstants.DEFAULT_CHARSET);
+
+				try {
+					recordAggregator.addUserRecord(partitionKey, data);
+				} catch (Exception e) {
+					throw new IllegalStateException("Error aggregating message", e);
+				}
+			}
+
+			AggRecord aggRecord = recordAggregator.clearAndGet();
+
+			recordBatch.add(new Record()
+				.withData(ByteBuffer.wrap(aggRecord.toRecordBytes()))
+				.withPartitionKey(partitionKey)
+				.withApproximateArrivalTimestamp(new Date(System.currentTimeMillis()))
+				.withSequenceNumber(String.valueOf(sequenceNumber.getAndAdd(numOfChildRecords))));
+		}
+
+		return recordBatch;
+	}
+
 }
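
Note on createAggregatedRecordBatch: the shared AtomicInteger threads sequence numbers across batches, so consecutive getRecords() responses stay monotonically increasing. A minimal usage sketch (the counter values follow from getAndAdd(numOfChildRecords) above):

    // One counter shared across batches keeps sequence numbers increasing.
    AtomicInteger sequenceNumber = new AtomicInteger();
    List<Record> first = TestUtils.createAggregatedRecordBatch(3, 5, sequenceNumber);
    // First batch's aggregated records carry sequence numbers "0", "5", "10";
    // the counter now reads 15.
    List<Record> second = TestUtils.createAggregatedRecordBatch(3, 5, sequenceNumber);
    // Second batch continues at "15", "20", "25"; the counter now reads 30.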