You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2015/10/26 05:18:39 UTC

spark git commit: [SPARK-10891][STREAMING][KINESIS] Add MessageHandler to KinesisUtils.createStream similar to Direct Kafka

Repository: spark
Updated Branches:
  refs/heads/master 80279ac18 -> 63accc796


[SPARK-10891][STREAMING][KINESIS] Add MessageHandler to KinesisUtils.createStream similar to Direct Kafka

This PR allows users to map a Kinesis `Record` to a generic `T` when creating a Kinesis stream. This is particularly useful if you would like to do extra work with Kinesis metadata, such as the sequence number and partition key.

TODO:
 - [x] add tests

Author: Burak Yavuz <br...@gmail.com>

Closes #8954 from brkyvz/kinesis-handler.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/63accc79
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/63accc79
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/63accc79

Branch: refs/heads/master
Commit: 63accc79625d8a03d0624717af5e1d81b18a6da3
Parents: 80279ac
Author: Burak Yavuz <br...@gmail.com>
Authored: Sun Oct 25 21:18:35 2015 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Sun Oct 25 21:18:35 2015 -0700

----------------------------------------------------------------------
 .../kinesis/KinesisBackedBlockRDD.scala         |  35 ++-
 .../streaming/kinesis/KinesisInputDStream.scala |  15 +-
 .../streaming/kinesis/KinesisReceiver.scala     |  18 +-
 .../kinesis/KinesisRecordProcessor.scala        |   4 +-
 .../spark/streaming/kinesis/KinesisUtils.scala  | 247 +++++++++++++++++--
 .../kinesis/JavaKinesisStreamSuite.java         |  29 ++-
 .../kinesis/KinesisBackedBlockRDDSuite.scala    |  16 +-
 .../kinesis/KinesisReceiverSuite.scala          |   4 +-
 .../streaming/kinesis/KinesisStreamSuite.scala  |  44 +++-
 9 files changed, 337 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
index 5d32fa6..000897a 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming.kinesis
 
 import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
@@ -67,7 +68,7 @@ class KinesisBackedBlockRDDPartition(
  * sequence numbers of the corresponding blocks.
  */
 private[kinesis]
-class KinesisBackedBlockRDD(
+class KinesisBackedBlockRDD[T: ClassTag](
     @transient sc: SparkContext,
     val regionName: String,
     val endpointUrl: String,
@@ -75,8 +76,9 @@ class KinesisBackedBlockRDD(
     @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
     @transient isBlockIdValid: Array[Boolean] = Array.empty,
     val retryTimeoutMs: Int = 10000,
+    val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
     val awsCredentialsOption: Option[SerializableAWSCredentials] = None
-  ) extends BlockRDD[Array[Byte]](sc, blockIds) {
+  ) extends BlockRDD[T](sc, blockIds) {
 
   require(blockIds.length == arrayOfseqNumberRanges.length,
     "Number of blockIds is not equal to the number of sequence number ranges")
@@ -90,23 +92,23 @@ class KinesisBackedBlockRDD(
     }
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
     val blockManager = SparkEnv.get.blockManager
     val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
     val blockId = partition.blockId
 
-    def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = {
+    def getBlockFromBlockManager(): Option[Iterator[T]] = {
       logDebug(s"Read partition data of $this from block manager, block $blockId")
-      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]])
+      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
     }
 
-    def getBlockFromKinesis(): Iterator[Array[Byte]] = {
-      val credenentials = awsCredentialsOption.getOrElse {
+    def getBlockFromKinesis(): Iterator[T] = {
+      val credentials = awsCredentialsOption.getOrElse {
         new DefaultAWSCredentialsProviderChain().getCredentials()
       }
       partition.seqNumberRanges.ranges.iterator.flatMap { range =>
-        new KinesisSequenceRangeIterator(
-          credenentials, endpointUrl, regionName, range, retryTimeoutMs)
+        new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
+          range, retryTimeoutMs).map(messageHandler)
       }
     }
     if (partition.isBlockIdValid) {
@@ -129,8 +131,7 @@ class KinesisSequenceRangeIterator(
     endpointUrl: String,
     regionId: String,
     range: SequenceNumberRange,
-    retryTimeoutMs: Int
-  ) extends NextIterator[Array[Byte]] with Logging {
+    retryTimeoutMs: Int) extends NextIterator[Record] with Logging {
 
   private val client = new AmazonKinesisClient(credentials)
   private val streamName = range.streamName
@@ -142,8 +143,8 @@ class KinesisSequenceRangeIterator(
 
   client.setEndpoint(endpointUrl, "kinesis", regionId)
 
-  override protected def getNext(): Array[Byte] = {
-    var nextBytes: Array[Byte] = null
+  override protected def getNext(): Record = {
+    var nextRecord: Record = null
     if (toSeqNumberReceived) {
       finished = true
     } else {
@@ -170,10 +171,7 @@ class KinesisSequenceRangeIterator(
       } else {
 
         // Get the record, copy the data into a byte array and remember its sequence number
-        val nextRecord: Record = internalIterator.next()
-        val byteBuffer = nextRecord.getData()
-        nextBytes = new Array[Byte](byteBuffer.remaining())
-        byteBuffer.get(nextBytes)
+        nextRecord = internalIterator.next()
         lastSeqNumber = nextRecord.getSequenceNumber()
 
         // If the this record's sequence number matches the stopping sequence number, then make sure
@@ -182,9 +180,8 @@ class KinesisSequenceRangeIterator(
           toSeqNumberReceived = true
         }
       }
-
     }
-    nextBytes
+    nextRecord
   }
 
   override protected def close(): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
index 2e4204d..72ab635 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.streaming.kinesis
 
+import scala.reflect.ClassTag
+
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -26,7 +29,7 @@ import org.apache.spark.streaming.receiver.Receiver
 import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
 import org.apache.spark.streaming.{Duration, StreamingContext, Time}
 
-private[kinesis] class KinesisInputDStream(
+private[kinesis] class KinesisInputDStream[T: ClassTag](
     @transient _ssc: StreamingContext,
     streamName: String,
     endpointUrl: String,
@@ -35,11 +38,12 @@ private[kinesis] class KinesisInputDStream(
     checkpointAppName: String,
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
+    messageHandler: Record => T,
     awsCredentialsOption: Option[SerializableAWSCredentials]
-  ) extends ReceiverInputDStream[Array[Byte]](_ssc) {
+  ) extends ReceiverInputDStream[T](_ssc) {
 
   private[streaming]
-  override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = {
+  override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = {
 
     // This returns true even for when blockInfos is empty
     val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty)
@@ -56,6 +60,7 @@ private[kinesis] class KinesisInputDStream(
         context.sc, regionName, endpointUrl, blockIds, seqNumRanges,
         isBlockIdValid = isBlockIdValid,
         retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
+        messageHandler = messageHandler,
         awsCredentialsOption = awsCredentialsOption)
     } else {
       logWarning("Kinesis sequence number information was not present with some block metadata," +
@@ -64,8 +69,8 @@ private[kinesis] class KinesisInputDStream(
     }
   }
 
-  override def getReceiver(): Receiver[Array[Byte]] = {
+  override def getReceiver(): Receiver[T] = {
     new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
-      checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption)
+      checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 6e0988c..134d627 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -80,7 +80,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
  * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
  *                             the credentials
  */
-private[kinesis] class KinesisReceiver(
+private[kinesis] class KinesisReceiver[T](
     val streamName: String,
     endpointUrl: String,
     regionName: String,
@@ -88,8 +88,9 @@ private[kinesis] class KinesisReceiver(
     checkpointAppName: String,
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
-    awsCredentialsOption: Option[SerializableAWSCredentials]
-  ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver =>
+    messageHandler: Record => T,
+    awsCredentialsOption: Option[SerializableAWSCredentials])
+  extends Receiver[T](storageLevel) with Logging { receiver =>
 
   /*
    * =================================================================================
@@ -202,12 +203,7 @@ private[kinesis] class KinesisReceiver(
   /** Add records of the given shard to the current block being generated */
   private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = {
     if (records.size > 0) {
-      val dataIterator = records.iterator().asScala.map { record =>
-        val byteBuffer = record.getData()
-        val byteArray = new Array[Byte](byteBuffer.remaining())
-        byteBuffer.get(byteArray)
-        byteArray
-      }
+      val dataIterator = records.iterator().asScala.map(messageHandler)
       val metadata = SequenceNumberRange(streamName, shardId,
         records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber())
       blockGenerator.addMultipleDataWithCallback(dataIterator, metadata)
@@ -240,7 +236,7 @@ private[kinesis] class KinesisReceiver(
 
   /** Store the block along with its associated ranges */
   private def storeBlockWithRanges(
-      blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = {
+      blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = {
     val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId)
     if (rangesToReportOption.isEmpty) {
       stop("Error while storing block into Spark, could not find sequence number ranges " +
@@ -325,7 +321,7 @@ private[kinesis] class KinesisReceiver(
     /** Callback method called when a block is ready to be pushed / stored. */
     def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
       storeBlockWithRanges(blockId,
-        arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]])
+        arrayBuffer.asInstanceOf[mutable.ArrayBuffer[T]])
     }
 
     /** Callback called in case of any error in internal of the BlockGenerator */

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index b240512..1d51787 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -41,8 +41,8 @@ import org.apache.spark.Logging
  * @param checkpointState represents the checkpoint state including the next checkpoint time.
  *   It's injected here for mocking purposes.
  */
-private[kinesis] class KinesisRecordProcessor(
-    receiver: KinesisReceiver,
+private[kinesis] class KinesisRecordProcessor[T](
+    receiver: KinesisReceiver[T],
     workerId: String,
     checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index c799fad..2849fd8 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -16,15 +16,18 @@
  */
 package org.apache.spark.streaming.kinesis
 
+import scala.reflect.ClassTag
+
 import com.amazonaws.regions.RegionUtils
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
 
+import org.apache.spark.api.java.function.{Function => JFunction}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
 import org.apache.spark.streaming.{Duration, StreamingContext}
 
-
 object KinesisUtils {
   /**
    * Create an input stream that pulls messages from a Kinesis stream.
@@ -52,6 +55,107 @@ object KinesisUtils {
    *                            details on the different types of checkpoints.
    * @param storageLevel Storage level to use for storing the received objects.
    *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data, and metadata.
+   */
+  def createStream[T: ClassTag](
+      ssc: StreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: Record => T): ReceiverInputDStream[T] = {
+    val cleanedHandler = ssc.sc.clean(messageHandler)
+    // Setting scope to override receiver stream's scope of "receiver stream"
+    ssc.withNamedScope("kinesis stream") {
+      new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
+        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+        cleanedHandler, None)
+    }
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note:
+   *  The given AWS credentials will get saved in DStream checkpoints if checkpointing
+   *  is enabled. Make sure that your checkpoint directory is secure.
+   *
+   * @param ssc StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data, and metadata.
+   * @param awsAccessKeyId  AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param awsSecretKey  AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+   */
+  // scalastyle:off
+  def createStream[T: ClassTag](
+      ssc: StreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: Record => T,
+      awsAccessKeyId: String,
+      awsSecretKey: String): ReceiverInputDStream[T] = {
+    // scalastyle:on
+    val cleanedHandler = ssc.sc.clean(messageHandler)
+    ssc.withNamedScope("kinesis stream") {
+      new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
+        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+        cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+    }
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
+   * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
+   * gets the AWS credentials.
+   *
+   * @param ssc StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
    */
   def createStream(
       ssc: StreamingContext,
@@ -61,12 +165,12 @@ object KinesisUtils {
       regionName: String,
       initialPositionInStream: InitialPositionInStream,
       checkpointInterval: Duration,
-      storageLevel: StorageLevel
-    ): ReceiverInputDStream[Array[Byte]] = {
+      storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = {
     // Setting scope to override receiver stream's scope of "receiver stream"
     ssc.withNamedScope("kinesis stream") {
-      new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
-        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None)
+      new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
+        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+        defaultMessageHandler, None)
     }
   }
 
@@ -109,12 +213,11 @@ object KinesisUtils {
       checkpointInterval: Duration,
       storageLevel: StorageLevel,
       awsAccessKeyId: String,
-      awsSecretKey: String
-    ): ReceiverInputDStream[Array[Byte]] = {
+      awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = {
     ssc.withNamedScope("kinesis stream") {
-      new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
+      new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
         initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
-        Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+        defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
     }
   }
 
@@ -156,8 +259,9 @@ object KinesisUtils {
       storageLevel: StorageLevel
     ): ReceiverInputDStream[Array[Byte]] = {
     ssc.withNamedScope("kinesis stream") {
-      new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl),
-        initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None)
+      new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl,
+        getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName,
+        checkpointInterval, storageLevel, defaultMessageHandler, None)
     }
   }
 
@@ -187,6 +291,107 @@ object KinesisUtils {
    *                            details on the different types of checkpoints.
    * @param storageLevel Storage level to use for storing the received objects.
    *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data, and metadata.
+   * @param recordClass Class of the records in DStream
+   */
+  def createStream[T](
+      jssc: JavaStreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: JFunction[Record, T],
+      recordClass: Class[T]): JavaReceiverInputDStream[T] = {
+    implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
+    val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
+    createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler)
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note:
+   * The given AWS credentials will get saved in DStream checkpoints if checkpointing
+   * is enabled. Make sure that your checkpoint directory is secure.
+   *
+   * @param jssc Java StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data, and metadata.
+   * @param recordClass Class of the records in DStream
+   * @param awsAccessKeyId  AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param awsSecretKey  AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+   */
+  // scalastyle:off
+  def createStream[T](
+      jssc: JavaStreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: JFunction[Record, T],
+      recordClass: Class[T],
+      awsAccessKeyId: String,
+      awsSecretKey: String): JavaReceiverInputDStream[T] = {
+    // scalastyle:on
+    implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
+    val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
+    createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler,
+      awsAccessKeyId, awsSecretKey)
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
+   * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
+   * gets the AWS credentials.
+   *
+   * @param jssc Java StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
    */
   def createStream(
       jssc: JavaStreamingContext,
@@ -198,8 +403,8 @@ object KinesisUtils {
       checkpointInterval: Duration,
       storageLevel: StorageLevel
     ): JavaReceiverInputDStream[Array[Byte]] = {
-    createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
-      initialPositionInStream, checkpointInterval, storageLevel)
+    createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_))
   }
 
   /**
@@ -241,10 +446,10 @@ object KinesisUtils {
       checkpointInterval: Duration,
       storageLevel: StorageLevel,
       awsAccessKeyId: String,
-      awsSecretKey: String
-    ): JavaReceiverInputDStream[Array[Byte]] = {
-    createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
-        initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey)
+      awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = {
+    createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel,
+      defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
   }
 
   /**
@@ -297,6 +502,14 @@ object KinesisUtils {
       throw new IllegalArgumentException(s"Region name '$regionName' is not valid")
     }
   }
+
+  /**
+   * Default message handler: copies the payload of a Kinesis record into a new byte array.
+   *
+   * The payload is read through a duplicate of the record's buffer, so the position of the
+   * buffer held by the record is not advanced and the record can be read again afterwards.
+   *
+   * @param record the Kinesis record to convert; may be null
+   * @return the readable bytes of the record's data, or null if `record` is null
+   */
+  private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = {
+    if (record == null) return null
+    // duplicate() shares the content but has its own position/limit; a relative get() on the
+    // original buffer would consume it and leave nothing for a second reader.
+    val payload = record.getData().duplicate()
+    val bytes = new Array[Byte](payload.remaining())
+    payload.get(bytes)
+    bytes
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
index 87954a3..3f0f679 100644
--- a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
+++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
@@ -17,14 +17,19 @@
 
 package org.apache.spark.streaming.kinesis;
 
+import com.amazonaws.services.kinesis.model.Record;
+import org.junit.Test;
+
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.LocalJavaStreamingContext;
 import org.apache.spark.streaming.api.java.JavaDStream;
-import org.junit.Test;
 
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
 
+import java.nio.ByteBuffer;
+
 /**
  * Demonstrate the use of the KinesisUtils Java API
  */
@@ -33,9 +38,27 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
   public void testKinesisStream() {
     // Tests the API, does not actually test data receiving
     JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream",
-        "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), 
+        "https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
         InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2());
-    
+
+    ssc.stop();
+  }
+
+
+  private static Function<Record, String> handler = new Function<Record, String>() {
+    @Override
+    public String call(Record record) {
+      return record.getPartitionKey() + "-" + record.getSequenceNumber();
+    }
+  };
+
+  @Test
+  public void testCustomHandler() {
+    // API check only: creates a stream with a custom message handler; no data is received.
+    JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
+        "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
+        new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class);
+
     ssc.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index a89e562..9f9e146 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -73,22 +73,22 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
 
   testIfEnabled("Basic reading from Kinesis") {
     // Verify all data using multiple ranges in a single RDD partition
-    val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
-      fakeBlockIds(1),
+    val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+      testUtils.endpointUrl, fakeBlockIds(1),
       Array(SequenceNumberRanges(allRanges.toArray))
     ).map { bytes => new String(bytes).toInt }.collect()
     assert(receivedData1.toSet === testData.toSet)
 
     // Verify all data using one range in each of the multiple RDD partitions
-    val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
-      fakeBlockIds(allRanges.size),
+    val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+      testUtils.endpointUrl, fakeBlockIds(allRanges.size),
       allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
     ).map { bytes => new String(bytes).toInt }.collect()
     assert(receivedData2.toSet === testData.toSet)
 
     // Verify ordering within each partition
-    val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
-      fakeBlockIds(allRanges.size),
+    val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+      testUtils.endpointUrl, fakeBlockIds(allRanges.size),
       allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
     ).map { bytes => new String(bytes).toInt }.collectPartitions()
     assert(receivedData3.length === allRanges.size)
@@ -209,7 +209,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
       }, "Incorrect configuration of RDD, unexpected ranges set"
     )
 
-    val rdd = new KinesisBackedBlockRDD(
+    val rdd = new KinesisBackedBlockRDD[Array[Byte]](
       sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges)
     val collectedData = rdd.map { bytes =>
       new String(bytes).toInt
@@ -223,7 +223,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
     if (testIsBlockValid) {
       require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
       require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
-      val rdd2 = new KinesisBackedBlockRDD(
+      val rdd2 = new KinesisBackedBlockRDD[Array[Byte]](
         sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges,
         isBlockIdValid = Array.fill(blockIds.length)(false))
       intercept[SparkException] {

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 3d136ae..17ab444 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -52,14 +52,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
   record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8)))
   val batch = Arrays.asList(record1, record2)
 
-  var receiverMock: KinesisReceiver = _
+  var receiverMock: KinesisReceiver[Array[Byte]] = _
   var checkpointerMock: IRecordProcessorCheckpointer = _
   var checkpointClockMock: ManualClock = _
   var checkpointStateMock: KinesisCheckpointState = _
   var currentClockMock: Clock = _
 
   override def beforeFunction(): Unit = {
-    receiverMock = mock[KinesisReceiver]
+    receiverMock = mock[KinesisReceiver[Array[Byte]]]
     checkpointerMock = mock[IRecordProcessorCheckpointer]
     checkpointClockMock = mock[ManualClock]
     checkpointStateMock = mock[KinesisCheckpointState]

http://git-wip-us.apache.org/repos/asf/spark/blob/63accc79/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
----------------------------------------------------------------------
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 1177dc7..ba84e55 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 
 import com.amazonaws.regions.RegionUtils
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
 import org.scalatest.Matchers._
 import org.scalatest.concurrent.Eventually
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
@@ -31,6 +32,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming._
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
 import org.apache.spark.streaming.kinesis.KinesisTestUtils._
 import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
 import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
@@ -113,9 +115,9 @@ class KinesisStreamSuite extends KinesisFunSuite
     val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream",
       dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2),
       StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey)
-    assert(inputStream.isInstanceOf[KinesisInputDStream])
+    assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]])
 
-    val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream]
+    val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]]
     val time = Time(1000)
 
     // Generate block info data for testing
@@ -134,8 +136,8 @@ class KinesisStreamSuite extends KinesisFunSuite
     // Verify that the generated KinesisBackedBlockRDD has all the right information
     val blockInfos = Seq(blockInfo1, blockInfo2)
     val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos)
-    nonEmptyRDD shouldBe a [KinesisBackedBlockRDD]
-    val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD]
+    nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
+    val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
     assert(kinesisRDD.regionName === dummyRegionName)
     assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
     assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
@@ -151,7 +153,7 @@ class KinesisStreamSuite extends KinesisFunSuite
 
     // Verify that KinesisBackedBlockRDD is generated even when there are no blocks
     val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty)
-    emptyRDD shouldBe a [KinesisBackedBlockRDD]
+    emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
     emptyRDD.partitions shouldBe empty
 
     // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid
@@ -192,6 +194,32 @@ class KinesisStreamSuite extends KinesisFunSuite
     ssc.stop(stopSparkContext = false)
   }
 
+  testIfEnabled("custom message handling") {
+    val awsCredentials = KinesisTestUtils.getAWSCredentials()
+    // Custom handler under test: parse the payload as an integer and add five. The payload is
+    // copied out with the default handler rather than ByteBuffer.array(), which ignores the
+    // buffer's position and limit and fails for buffers without an accessible backing array.
+    def addFive(r: Record): Int = {
+      val payload = KinesisUtils.defaultMessageHandler(r)
+      new String(payload, java.nio.charset.StandardCharsets.UTF_8).toInt + 5
+    }
+    val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+      testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+      Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
+      awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+    stream shouldBe a [ReceiverInputDStream[Int]]
+
+    val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
+    stream.foreachRDD { rdd =>
+      // Collect once and reuse the result; every collect() runs a separate Spark job.
+      val batch = rdd.collect()
+      collected ++= batch
+      logInfo("Collected = " + batch.toSeq.mkString(", "))
+    }
+    ssc.start()
+
+    val testData = 1 to 10
+    eventually(timeout(120 seconds), interval(10 second)) {
+      testUtils.pushData(testData)
+      val modData = testData.map(_ + 5)
+      assert(collected === modData.toSet, "\nData received does not match data sent")
+    }
+    ssc.stop(stopSparkContext = false)
+  }
+
   testIfEnabled("failure recovery") {
     val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
     val checkpointDir = Utils.createTempDir().getAbsolutePath
@@ -210,7 +238,7 @@ class KinesisStreamSuite extends KinesisFunSuite
 
     // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch
     kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => {
-      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
       val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq
       collectedData(time) = (kRdd.arrayOfseqNumberRanges, data)
     })
@@ -243,10 +271,10 @@ class KinesisStreamSuite extends KinesisFunSuite
     times.foreach { time =>
       val (arrayOfSeqNumRanges, data) = collectedData(time)
       val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]]
-      rdd shouldBe a [KinesisBackedBlockRDD]
+      rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
 
       // Verify the recovered sequence ranges
-      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
       assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size)
       arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) =>
         assert(expected.ranges.toSeq === found.ranges.toSeq)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org