You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@bahir.apache.org by lr...@apache.org on 2018/12/15 21:39:37 UTC

[2/3] bahir git commit: [BAHIR-183] Using HDFS for saving message for mqtt source.

[BAHIR-183] Using HDFS for saving message for mqtt source.

Closes #78


Project: http://git-wip-us.apache.org/repos/asf/bahir/repo
Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/172d7096
Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/172d7096
Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/172d7096

Branch: refs/heads/master
Commit: 172d7096147cd0be70687af893a4d71380ce47bf
Parents: 63878bf
Author: wangyanlin01 <wa...@baidu.com>
Authored: Sun Dec 2 11:00:21 2018 +0800
Committer: Luciano Resende <lr...@apache.org>
Committed: Sat Dec 15 18:39:19 2018 -0300

----------------------------------------------------------------------
 sql-streaming-mqtt/pom.xml                      |  12 +
 ....apache.spark.sql.sources.DataSourceRegister |   3 +-
 .../sql/streaming/mqtt/CachedMQTTClient.scala   |   2 +-
 .../sql/streaming/mqtt/MQTTStreamSink.scala     |   2 +-
 .../sql/streaming/mqtt/MQTTStreamSource.scala   |   2 +-
 .../bahir/sql/streaming/mqtt/MQTTUtils.scala    |  15 +-
 .../spark/sql/mqtt/HDFSMQTTSourceProvider.scala |  64 +++
 .../sql/mqtt/HdfsBasedMQTTStreamSource.scala    | 400 +++++++++++++++++++
 .../mqtt/HDFSBasedMQTTStreamSourceSuite.scala   | 198 +++++++++
 9 files changed, 689 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/pom.xml
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/pom.xml b/sql-streaming-mqtt/pom.xml
index 63497dc..05a3fff 100644
--- a/sql-streaming-mqtt/pom.xml
+++ b/sql-streaming-mqtt/pom.xml
@@ -85,6 +85,18 @@
       <version>5.13.3</version>
       <scope>test</scope>
     </dependency>
+    <!-- Hadoop test jars (they provide MiniDFSCluster for the HDFS based source tests).
+         Test scope keeps these test-only artifacts off the published compile classpath. -->
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-hdfs</artifactId>
+      <version>2.6.5</version>
+      <classifier>tests</classifier>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>2.6.5</version>
+      <classifier>tests</classifier>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index d3899e6..1920a6b 100644
--- a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -16,4 +16,5 @@
 #
 
 org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider
-org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider
\ No newline at end of file
+org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider
+org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
index fed2601..8925e93 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
@@ -66,7 +66,7 @@ private[mqtt] object CachedMQTTClient extends Logging {
 
   private def createMqttClient(config: Map[String, String]):
       (MqttClient, MqttClientPersistence) = {
-    val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _) =
+    val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _, _, _, _) =
       MQTTUtils.parseConfigParams(config)
     val client = new MqttClient(brokerUrl, clientId, persistence)
     val callback = new MqttCallbackExtended() {

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
index f449e57..846765c 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
@@ -52,7 +52,7 @@ class MQTTStreamWriter (schema: StructType, parameters: DataSourceOptions)
 
   initialize()
   private def initialize(): Unit = {
-    val (_, _, topic_, _, _, qos_) = MQTTUtils.parseConfigParams(
+    val (_, _, topic_, _, _, qos_, _, _, _) = MQTTUtils.parseConfigParams(
       collection.immutable.HashMap() ++ parameters.asMap().asScala
     )
     topic = topic_

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
index 98bc60e..a40ff51 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
@@ -244,7 +244,7 @@ class MQTTStreamSourceProvider extends DataSourceV2
     }
 
     import scala.collection.JavaConverters._
-    val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) =
+    val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos, _, _, _) =
       MQTTUtils.parseConfigParams(collection.immutable.HashMap() ++ parameters.asMap().asScala)
 
     new MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId,

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
index f0a6f1a..9df46bc 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
@@ -26,8 +26,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 
 import org.apache.bahir.utils.Logging
 
-
-private[mqtt] object MQTTUtils extends Logging {
+object MQTTUtils extends Logging {
   // Since data source configuration properties are case-insensitive,
   // we have to introduce our own keys. Also, good for vendor independence.
   private[mqtt] val sslParamMapping = Map(
@@ -45,8 +44,8 @@ private[mqtt] object MQTTUtils extends Logging {
     "ssl.trust.manager" -> "com.ibm.ssl.trustManager"
   )
 
-  private[mqtt] def parseConfigParams(config: Map[String, String]):
-      (String, String, String, MqttClientPersistence, MqttConnectOptions, Int) = {
+  def parseConfigParams(config: Map[String, String]):
+      (String, String, String, MqttClientPersistence, MqttConnectOptions, Int, Long, Long, Int) = {
     def e(s: String) = new IllegalArgumentException(s)
     val parameters = CaseInsensitiveMap(config)
 
@@ -84,6 +83,11 @@ private[mqtt] object MQTTUtils extends Logging {
     val autoReconnect: Boolean = parameters.getOrElse("autoReconnect", "false").toBoolean
     val maxInflight: Int = parameters.getOrElse("maxInflight", "60").toInt
 
+    val maxBatchMessageNum = parameters.getOrElse("maxBatchMessageNum", s"${Long.MaxValue}").toLong
+    val maxBatchMessageSize = parameters.getOrElse("maxBatchMessageSize",
+      s"${Long.MaxValue}").toLong
+    val maxRetryNumber = parameters.getOrElse("maxRetryNum", "3").toInt
+
     val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions()
     mqttConnectOptions.setAutomaticReconnect(autoReconnect)
     mqttConnectOptions.setCleanSession(cleanSession)
@@ -105,6 +109,7 @@ private[mqtt] object MQTTUtils extends Logging {
     })
     mqttConnectOptions.setSSLProperties(sslProperties)
 
-    (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos)
+    (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos,
+      maxBatchMessageNum, maxBatchMessageSize, maxRetryNumber)
   }
 }

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala
new file mode 100644
index 0000000..f38d842
--- /dev/null
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.mqtt
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.types.StructType
+
+import org.apache.bahir.sql.streaming.mqtt.{MQTTStreamConstants, MQTTUtils}
+
+/**
+ * Provider of the `hdfs-mqtt` streaming source, which persists received MQTT messages under the
+ * query checkpoint directory (see [[HdfsBasedMQTTStreamSource]]).
+ * Options are parsed by `MQTTUtils.parseConfigParams`; per the original contract, an
+ * IllegalArgumentException is thrown when 'brokerUrl' or 'topic' is not set in the options.
+ */
+class HDFSMQTTSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
+
+  /** The schema is fixed; any user-supplied schema is ignored. */
+  override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType],
+    providerName: String, parameters: Map[String, String]): (String, StructType) =
+    (shortName(), MQTTStreamConstants.SCHEMA_DEFAULT)
+
+  /** Parses the MQTT options and builds a source rooted at Spark's `metadataPath`. */
+  override def createSource(sqlContext: SQLContext, metadataPath: String,
+    schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = {
+    // The parsed client persistence (4th element) is not used by this source.
+    val (brokerUrl, clientId, topic, _, connectOptions, qos,
+      maxBatchMessageNum, maxBatchMessageSize, maxRetryNum) =
+      MQTTUtils.parseConfigParams(parameters)
+
+    new HdfsBasedMQTTStreamSource(
+      sqlContext,
+      metadataPath,
+      brokerUrl,
+      clientId,
+      topic,
+      connectOptions,
+      qos,
+      maxBatchNumber = maxBatchMessageNum,
+      maxBatchSize = maxBatchMessageSize,
+      maxRetryNumber = maxRetryNum)
+  }
+
+  override def shortName(): String = "hdfs-mqtt"
+}
+
+/** Constants shared by the `hdfs-mqtt` provider and its source. */
+object HDFSMQTTSourceProvider {
+  /** Field separator between the id, topic, timestamp and payload of a stored message line. */
+  val SEP: String = "##"
+}

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala
new file mode 100644
index 0000000..e6e202b
--- /dev/null
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.mqtt
+
+import java.io.IOException
+import java.nio.charset.StandardCharsets
+import java.sql.Timestamp
+import java.util.Calendar
+import java.util.concurrent.locks.{Lock, ReentrantLock}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path, PathFilter}
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.FileContextManager
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.bahir.sql.streaming.mqtt.{LongOffset, MQTTStreamConstants}
+
+/**
+ * A text based MQTT stream source that stores every received message on a Hadoop file system
+ * (typically HDFS) under the query checkpoint directory before Spark reads it.
+ *
+ * Messages are appended to numbered data files `<checkpointRoot>/receivedMessages/<N>`; the batch
+ * for source offset `N` is the content of data file `N`. Each message is stored as one line
+ * `messageId##topic##arrivalMillis##payload` (separator `HDFSMQTTSourceProvider.SEP`). Payload
+ * bytes are written verbatim but read back as UTF-8 text lines, so payloads containing a newline
+ * or invalid UTF-8 are not preserved faithfully. A new data file is started when the current one
+ * would exceed `maxBatchNumber` messages or `maxBatchSize` payload bytes, and when `getOffset`
+ * hands the current file out as a batch. Written data is only hsync'ed when a data file is
+ * closed (rollover or stop).
+ *
+ * @param sqlContext         Spark provided, SqlContext.
+ * @param metadataPath       metadata path Spark assigns to this source; expected to look like
+ *                           `<checkpointRoot>/sources/<sourceIndex>`.
+ * @param brokerUrl          url MqttClient connects to.
+ * @param clientId           clientId, this client is associated with.
+ *                           Provide the same value to recover a stopped client.
+ * @param topic              topic MqttClient subscribes to.
+ * @param mqttConnectOptions an instance of MqttConnectOptions for this Source.
+ * @param qos                the maximum quality of service to subscribe each topic at.
+ *                           Messages published at a lower quality of service will be received
+ *                           at the published QoS. Messages published at a higher quality of
+ *                           service will be received using the QoS specified on the subscribe.
+ * @param maxBatchNumber     the max message number to process in one batch.
+ * @param maxBatchSize       the max total payload size in one batch, measured in bytes. A single
+ *                           message larger than this still forms a batch on its own.
+ * @param maxRetryNumber     number of consecutive connection losses tolerated; once reached,
+ *                           `getOffset` and `getBatch` fail their assertion.
+ */
+class HdfsBasedMQTTStreamSource(
+  sqlContext: SQLContext,
+  metadataPath: String,
+  brokerUrl: String,
+  clientId: String,
+  topic: String,
+  mqttConnectOptions: MqttConnectOptions,
+  qos: Int,
+  maxBatchNumber: Long = Long.MaxValue,
+  maxBatchSize: Long = Long.MaxValue,
+  maxRetryNumber: Int = 3
+) extends Source with Logging {
+
+  import HDFSMQTTSourceProvider.SEP
+
+  override def schema: StructType = MQTTStreamConstants.SCHEMA_DEFAULT
+
+  // Offset (equal to the data file index) of the last batch returned by getOffset;
+  // -1 means no batch has been produced yet.
+  private var lastOffset: Long = -1L
+
+  // Index of the data file that incoming messages are currently appended to.
+  private var currentMessageDataFileOffset: Long = 0L
+
+  // File system holding the data files, and the output stream of the data file being written.
+  private var fs: FileSystem = _
+  private var messageStoreOutputStream: FSDataOutputStream = _
+
+  // Number of messages written to the current data file. Long, like `maxBatchNumber`.
+  private var messageNumberForCurrentBatch: Long = 0L
+  // Total payload bytes written to the current data file. Long so it cannot overflow past
+  // Int.MaxValue while being compared against `maxBatchSize`.
+  private var messageSizeForCurrentBatch: Long = 0L
+
+  private val minBatchesToRetain = sqlContext.sparkSession.sessionState.conf.minBatchesToRetain
+
+  // Consecutive connection losses since the last stored message; must stay below
+  // `maxRetryNumber`.
+  private var consecutiveFailNum = 0
+
+  private var client: MqttClient = _
+
+  // Guards the mutable state above, which is shared between the MQTT client callbacks and the
+  // Source methods called by Spark.
+  private val lock: Lock = new ReentrantLock()
+
+  // Hadoop configuration for the data files: the companion-object override when set, otherwise
+  // the session's Hadoop configuration so `spark.hadoop.*` settings are honoured.
+  private val hadoopConfig: Configuration = if (HdfsBasedMQTTStreamSource.hadoopConfig != null) {
+    logInfo("using setted hadoop configuration!")
+    HdfsBasedMQTTStreamSource.hadoopConfig
+  } else {
+    logInfo("create a new configuration.")
+    sqlContext.sparkSession.sessionState.newHadoopConf()
+  }
+
+  // Checkpoint root: the grandparent of `metadataPath` (`<checkpointRoot>/sources/<index>`).
+  private val rootCheckpointPath = {
+    val path = new Path(metadataPath).getParent.getParent.toUri.toString
+    logInfo(s"get rootCheckpointPath $path")
+    path
+  }
+
+  // Directory holding the data files, one file per batch.
+  // NOTE(review): this directory is shared by every hdfs-mqtt source of the same query, so a
+  // second such source would collide on data file names and fail when creating them.
+  private val receivedDataPath = s"$rootCheckpointPath/receivedMessages"
+
+  // Lazily recovers, from the query's offset WAL, the last offset this source reported.
+  private lazy val recoveredLatestOffset: Option[Long] = {
+    // The index of this source, parsed from the last segment of the metadata path.
+    val currentSourceIndex = {
+      if (!metadataPath.isEmpty) {
+        metadataPath.substring(metadataPath.lastIndexOf("/") + 1).toInt
+      } else {
+        -1
+      }
+    }
+    if (currentSourceIndex >= 0) {
+      val offsetLog = new OffsetSeqLog(sqlContext.sparkSession,
+        new Path(rootCheckpointPath, "offsets").toUri.toString)
+      offsetLog.getLatest() match {
+        case Some((batchId, offsetSeq)) =>
+          // Use the offset recorded for this source: batch ids only equal this source's
+          // offsets when every batch advances this source, which is not guaranteed.
+          val recovered = offsetSeq.offsets.lift(currentSourceIndex).flatten.map(getOffsetValue)
+          logInfo(s"get latest batch $batchId, recovered offset $recovered")
+          recovered
+        case None =>
+          logInfo("no offset avaliable in offset log")
+          None
+      }
+    } else {
+      logInfo("checkpoint path is not set")
+      None
+    }
+  }
+
+  initialize()
+
+  // Closes the current data file (if open) and advances to the next file index, resetting the
+  // per-batch counters. Caller must hold `lock`.
+  private def startWriteNewDataFile(): Unit = {
+    if (messageStoreOutputStream != null) {
+      logInfo(s"Need to write a new data file,"
+        + s" close current data file index $currentMessageDataFileOffset")
+      messageStoreOutputStream.flush()
+      messageStoreOutputStream.hsync()
+      messageStoreOutputStream.close()
+      messageStoreOutputStream = null
+    }
+    currentMessageDataFileOffset += 1
+    messageSizeForCurrentBatch = 0L
+    messageNumberForCurrentBatch = 0L
+  }
+
+  // Accounts for messages written to the current data file. Caller must hold `lock`.
+  private def addReceivedMessageInfo(messageNum: Int, messageSize: Int): Unit = {
+    messageSizeForCurrentBatch += messageSize
+    messageNumberForCurrentBatch += messageNum
+  }
+
+  // True when a completed data file or buffered messages exist beyond `lastOffset`.
+  // Caller must hold `lock`.
+  private def hasNewMessageForCurrentBatch(): Boolean = {
+    currentMessageDataFileOffset > lastOffset + 1 || messageNumberForCurrentBatch > 0
+  }
+
+  private def withLock[T](body: => T): T = {
+    lock.lock()
+    try body
+    finally lock.unlock()
+  }
+
+  // Recovers offsets, opens the file system, then connects and subscribes the MQTT client.
+  private def initialize(): Unit = {
+    recoveredLatestOffset.foreach { offset =>
+      lastOffset = offset
+      logInfo(s"Recover lastOffset value ${lastOffset}")
+    }
+
+    val dataPath = new Path(receivedDataPath)
+    // A dedicated instance resolved from the data path itself: it works when the checkpoint is
+    // not on the default file system, and stop() can close it without closing the cached
+    // instance that FileSystem.get shares with the rest of the JVM.
+    fs = FileSystem.newInstance(dataPath.toUri, hadoopConfig)
+    currentMessageDataFileOffset = recoverMessageDataFileOffset(dataPath)
+
+    client = new MqttClient(brokerUrl, clientId, new MemoryPersistence())
+    client.setCallback(new MqttCallbackExtended() {
+
+      override def messageArrived(topic: String, message: MqttMessage): Unit = {
+        withLock[Unit] {
+          storeMessage(topic, message)
+        }
+      }
+
+      override def deliveryComplete(token: IMqttDeliveryToken): Unit = {
+        // callback for publisher, no need here.
+      }
+
+      override def connectionLost(cause: Throwable): Unit = {
+        // Only counted here; reconnection is left to the client's connect options.
+        withLock[Unit] {
+          consecutiveFailNum += 1
+          logWarning(s"Connection to mqtt server lost, "
+            + s"consecutive fail number $consecutiveFailNum", cause)
+        }
+      }
+
+      override def connectComplete(reconnect: Boolean, serverURI: String): Unit = {
+        logInfo(s"Connect complete $serverURI. Is it a reconnect?: $reconnect")
+      }
+    })
+    client.connect(mqttConnectOptions)
+    client.subscribe(topic, qos)
+  }
+
+  // Returns the index of the next data file to write: one past the highest numerically named
+  // file under `dataPath`, or 0 (creating `dataPath` if missing) when there is none.
+  private def recoverMessageDataFileOffset(dataPath: Path): Long = {
+    if (fs.exists(dataPath)) {
+      val fileManager = new FileContextManager(dataPath, hadoopConfig)
+      val dataFileIndexs = fileManager.list(dataPath, new PathFilter {
+        private def isBatchFile(path: Path) = {
+          try {
+            path.getName.toLong
+            true
+          } catch {
+            case _: NumberFormatException => false
+          }
+        }
+
+        override def accept(path: Path): Boolean = isBatchFile(path)
+      }).map(_.getPath.getName.toLong)
+      if (dataFileIndexs.nonEmpty) {
+        val nextOffset = dataFileIndexs.max + 1
+        assert(nextOffset >= lastOffset + 1,
+          s"Recovered invalid message data file offset $nextOffset,"
+            + s"do not match with lastOffset $lastOffset")
+        logInfo(s"Recovered last message data file offset: ${nextOffset - 1}, "
+          + s"start from $nextOffset")
+        nextOffset
+      } else {
+        logInfo("No old data file exist, start data file index from 0")
+        0L
+      }
+    } else {
+      logInfo(s"Create data dir $receivedDataPath, start data file index from 0")
+      fs.mkdirs(dataPath)
+      0L
+    }
+  }
+
+  // Appends one message to the current data file, first rolling over to a new file when adding
+  // it would exceed a per-batch limit. Caller must hold `lock`.
+  private def storeMessage(topic: String, message: MqttMessage): Unit = {
+    val payload = message.getPayload
+    val messageSize = payload.length
+    // Only roll over a non-empty file: rolling over an empty one would skip a data file index,
+    // leaving an offset whose data file never exists for getBatch to read.
+    val exceedsLimit = messageNumberForCurrentBatch + 1 > maxBatchNumber ||
+      messageSizeForCurrentBatch + messageSize > maxBatchSize
+    if (messageNumberForCurrentBatch > 0 && exceedsLimit) {
+      startWriteNewDataFile()
+    }
+
+    if (messageStoreOutputStream == null) {
+      val path = new Path(s"${receivedDataPath}/${currentMessageDataFileOffset}")
+      if (fs.exists(path)) {
+        throw new IOException(s"${path.toUri.toString} already exist, "
+          + s"make sure do use unique checkpoint path for each app.")
+      }
+      // create() instead of createNewFile() + append(): append is optional in the Hadoop
+      // FileSystem API. overwrite = false still refuses to clobber a concurrently created file.
+      messageStoreOutputStream = fs.create(path, false)
+      logInfo(s"Create new message data file ${path.toUri.toString} success!")
+    }
+
+    // Encode the header as UTF-8, which is what getBatch's textFile decodes; writeBytes would
+    // keep only the low byte of each char and corrupt non-ASCII topics.
+    val timestamp = Calendar.getInstance().getTimeInMillis()
+    val header = s"${message.getId}${SEP}${topic}${SEP}${timestamp}${SEP}"
+    messageStoreOutputStream.write(header.getBytes(StandardCharsets.UTF_8))
+    messageStoreOutputStream.write(payload)
+    messageStoreOutputStream.writeBytes("\n")
+    addReceivedMessageInfo(1, messageSize)
+    consecutiveFailNum = 0
+    logDebug(s"Message arrived, topic: $topic, message payload $message, "
+      + s"messageId: ${message.getId}, message size: ${messageSize}")
+  }
+
+  /**
+   * Stop this source and free any resources it has allocated.
+   * The open data file is flushed and closed, and the file system released, even when shutting
+   * down the MQTT client throws.
+   */
+  override def stop(): Unit = {
+    logInfo("Stop mqtt source.")
+    try {
+      // Disconnect first so no more messages arrive while the data file is being closed.
+      client.disconnect()
+      client.close()
+    } finally {
+      withLock[Unit] {
+        if (messageStoreOutputStream != null) {
+          messageStoreOutputStream.hflush()
+          messageStoreOutputStream.hsync()
+          messageStoreOutputStream.close()
+          messageStoreOutputStream = null
+        }
+        // Safe to close: this is the source's own instance from FileSystem.newInstance.
+        fs.close()
+      }
+    }
+  }
+
+  /**
+   * Returns the maximum available offset for this source.
+   * Each call hands out at most one new offset (the next data file), closing that file first if
+   * it is still being written.
+   */
+  override def getOffset: Option[Offset] = {
+    withLock[Option[Offset]] {
+      assert(consecutiveFailNum < maxRetryNumber,
+        s"Write message data fail continuously for ${maxRetryNumber} times.")
+      val result = if (!hasNewMessageForCurrentBatch()) {
+        if (lastOffset == -1) {
+          // first submit and no message has arrived.
+          None
+        } else {
+          // no message has arrived for this batch.
+          Some(LongOffset(lastOffset))
+        }
+      } else {
+        // The batch to hand out is still the file being written: close it first.
+        if (currentMessageDataFileOffset == lastOffset + 1) {
+          startWriteNewDataFile()
+        }
+        lastOffset += 1
+        Some(LongOffset(lastOffset))
+      }
+      logInfo(s"getOffset result $result")
+      result
+    }
+  }
+
+  /**
+   * Returns the data that is between the offsets (`start`, `end`].
+   * The batch is the content of data file `${checkpointPath}/receivedMessages/${end}`.
+   * When `start` is defined, it must satisfy `end value` = `start value` + 1.
+   */
+  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+    withLock[Unit] {
+      assert(consecutiveFailNum < maxRetryNumber,
+        s"Write message data fail continuously for ${maxRetryNumber} times.")
+    }
+    logInfo(s"getBatch with start = $start, end = $end")
+    val endIndex = getOffsetValue(end)
+    start.map(getOffsetValue).foreach { startIndex =>
+      assert(startIndex + 1 == endIndex,
+        s"start offset: ${startIndex} and end offset: ${endIndex} do not match")
+    }
+    val dataFile = s"$receivedDataPath/$endIndex"
+    logTrace(s"Create a data frame using hdfs file $dataFile")
+    // The closure only uses its argument and the static SEP, so it does not capture this source.
+    val rdd = sqlContext.sparkContext.textFile(dataFile).map { line =>
+      // `messageId##topic##arrivalMillis##payload`; the payload is the rest of the line.
+      val idEnd = line.indexOf(SEP)
+      val messageId = line.substring(0, idEnd).toInt
+      val topicStart = idEnd + SEP.length
+      val topicEnd = line.indexOf(SEP, topicStart)
+      val topic = UTF8String.fromString(line.substring(topicStart, topicEnd))
+      val timestampStart = topicEnd + SEP.length
+      val timestampEnd = line.indexOf(SEP, timestampStart)
+      // TimestampType values in an InternalRow are microseconds since the epoch, while the data
+      // file records milliseconds.
+      val timestampMicros = line.substring(timestampStart, timestampEnd).toLong * 1000L
+      val payload = line.substring(timestampEnd + SEP.length).getBytes(StandardCharsets.UTF_8)
+      InternalRow(messageId, topic, payload, timestampMicros)
+    }
+    sqlContext.internalCreateDataFrame(rdd, MQTTStreamConstants.SCHEMA_DEFAULT, true)
+  }
+
+  /**
+   * Deletes the data file `minBatchesToRetain` batches older than `end`.
+   * Deletion failures are logged and otherwise ignored.
+   *
+   * @param end the end of offset that all data has been committed.
+   */
+  override def commit(end: Offset): Unit = {
+    val offsetValue = getOffsetValue(end)
+    if (offsetValue >= minBatchesToRetain) {
+      val deleteDataFileOffset = offsetValue - minBatchesToRetain
+      try {
+        fs.delete(new Path(s"$receivedDataPath/$deleteDataFileOffset"), false)
+        logInfo(s"Delete committed offset data file $deleteDataFileOffset success!")
+      } catch {
+        case e: Exception =>
+          logWarning(s"Delete committed offset data file $deleteDataFileOffset failed. ", e)
+      }
+    }
+  }
+
+  // Numeric value of an offset produced by this source, either live or read back from the WAL.
+  // Other offset types throw a MatchError.
+  private def getOffsetValue(offset: Offset): Long = offset match {
+    case o: LongOffset => o.offset
+    case so: SerializedOffset => so.json.toLong
+  }
+}
+/** Companion of [[HdfsBasedMQTTStreamSource]] holding an optional Hadoop configuration override. */
+object HdfsBasedMQTTStreamSource {
+
+  /**
+   * When non-null, every source constructed afterwards uses this configuration instead of
+   * building its own. Presumably intended as a test hook (e.g. for a mini HDFS cluster) — confirm.
+   * Left null, each source builds its own configuration.
+   */
+  var hadoopConfig: Configuration = null
+}

http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala
new file mode 100644
index 0000000..777db16
--- /dev/null
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.mqtt
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hdfs.MiniDFSCluster
+import org.apache.hadoop.security.Groups
+import org.eclipse.paho.client.mqttv3.MqttException
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+import org.apache.spark.sql._
+import org.apache.spark.sql.mqtt.{HdfsBasedMQTTStreamSource, HDFSMQTTSourceProvider}
+import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery}
+
+import org.apache.bahir.utils.FileHelper
+
+/**
+ * Base fixture for HDFS-based MQTT source tests: starts an embedded MQTT broker and a
+ * MiniDFSCluster before each test and tears both down afterwards.
+ */
+class HDFSBasedMQTTStreamSourceSuite
+    extends SparkFunSuite
+    with SharedSparkContext
+    with BeforeAndAfter {
+
+  protected var mqttTestUtils: MQTTTestUtils = _
+  protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/")
+  protected var hadoop: MiniDFSCluster = _
+
+  before {
+    tempDir.mkdirs()
+    if (!tempDir.exists()) {
+      throw new IllegalStateException("Unable to create temp directories.")
+    }
+    tempDir.deleteOnExit()
+    mqttTestUtils = new MQTTTestUtils(tempDir)
+    mqttTestUtils.setup()
+    hadoop = HDFSTestUtils.prepareHadoop()
+  }
+
+  after {
+    // Release each resource even if an earlier teardown step throws, so a failure here cannot
+    // leave a MiniDFSCluster or temp directory behind for the next test.
+    try {
+      mqttTestUtils.teardown()
+    } finally {
+      try {
+        HDFSTestUtils.shutdownHadoop()
+      } finally {
+        FileHelper.deleteFileQuietly(tempDir)
+      }
+    }
+  }
+
+  protected val tmpDir: String = tempDir.getAbsolutePath
+
+  /**
+   * Starts a query writing the `payload` column (cast to string) as CSV to `t.csv` under
+   * [[tmpDir]], then waits until a trigger is active.
+   *
+   * The wait also ends if the query stops, so a query that fails during startup surfaces its
+   * error to the caller (e.g. from `processAllAvailable`) instead of hanging the test forever.
+   */
+  protected def writeStreamResults(sqlContext: SQLContext, dataFrame: DataFrame): StreamingQuery = {
+    import sqlContext.implicits._
+    val query: StreamingQuery = dataFrame.selectExpr("CAST(payload AS STRING)").as[String]
+      .writeStream.format("csv").start(s"$tmpDir/t.csv")
+    while (query.isActive && !query.status.isTriggerActive) {
+      Thread.sleep(20)
+    }
+    query
+  }
+
+  /** Reads back, as strings, every row written to `t.csv` under [[tmpDir]]. */
+  protected def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = {
+    import sqlContext.implicits._
+    sqlContext.read
+      .csv(s"$tmpDir/t.csv").as[String]
+      .collectAsList().asScala
+  }
+
+  /**
+   * Creates a streaming DataFrame from the HDFS-backed MQTT source subscribed to topic "test" on
+   * the embedded broker, with the streaming checkpoint location set to `/testCheckpoint` on the
+   * mini HDFS cluster.
+   *
+   * NOTE(review): `dir` is not referenced by this implementation.
+   */
+  protected def createStreamingDataFrame(dir: String = tmpDir): (SQLContext, DataFrame) = {
+
+    val sqlContext: SQLContext = SparkSession.builder()
+      .getOrCreate().sqlContext
+
+    sqlContext.setConf("spark.sql.streaming.checkpointLocation",
+      s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint")
+
+    val ds: DataStreamReader =
+      sqlContext.readStream.format("org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider")
+        .option("topic", "test").option("clientId", "clientId").option("connectionTimeout", "120")
+        .option("keepAlive", "1200").option("autoReconnect", "false")
+        .option("cleanSession", "true").option("QoS", "2")
+    val dataFrame = ds.load("tcp://" + mqttTestUtils.brokerUri)
+    (sqlContext, dataFrame)
+  }
+}
+
+/** Manages a single shared [[MiniDFSCluster]] for the HDFS-based MQTT source tests. */
+object HDFSTestUtils {
+
+  private var hadoop: MiniDFSCluster = _
+
+  /**
+   * Returns the running mini cluster, starting one if none is running.
+   *
+   * On start this also sets `fs.defaultFS` to the new cluster and publishes that configuration
+   * through `HdfsBasedMQTTStreamSource.hadoopConfig`.
+   */
+  def prepareHadoop(): MiniDFSCluster = {
+    if (hadoop == null) {
+      val baseDir = new File(System.getProperty("java.io.tmpdir") + "/hadoop").getAbsoluteFile
+      System.setProperty("HADOOP_USER_NAME", "test")
+      val conf = new Configuration
+      conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath)
+      conf.setBoolean("dfs.namenode.acls.enabled", true)
+      // Current name of the deprecated "dfs.permissions" key.
+      conf.setBoolean("dfs.permissions.enabled", true)
+      // Called for its side effect: Hadoop caches the groups-mapping service built from `conf`.
+      Groups.getUserToGroupsMappingService(conf)
+      hadoop = new MiniDFSCluster.Builder(conf).build
+      conf.set("fs.defaultFS", "hdfs://localhost:" + hadoop.getNameNodePort + "/")
+      HdfsBasedMQTTStreamSource.hadoopConfig = conf
+    }
+    hadoop
+  }
+
+  /** Shuts down the running mini cluster (deleting its DFS directories), if there is one. */
+  def shutdownHadoop(): Unit = {
+    try {
+      if (hadoop != null) {
+        hadoop.shutdown(true)
+      }
+    } finally {
+      // Forget the reference even if shutdown throws, so the next prepareHadoop() starts a fresh
+      // cluster instead of returning a half-stopped one.
+      hadoop = null
+    }
+  }
+}
+
+class BasicHDFSBasedMQTTSourceSuite extends HDFSBasedMQTTStreamSourceSuite {
+
+  /**
+   * Starts a streaming query over the HDFS-backed MQTT source, runs `publish`, and returns every
+   * row the query wrote.
+   *
+   * The query is always stopped afterwards, so it cannot keep running against the broker and
+   * HDFS cluster that `after` tears down, or leak into later tests.
+   */
+  private def runQueryAndReadBack(publish: => Unit): mutable.Buffer[String] = {
+    val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame()
+    val query = writeStreamResults(sqlContext, dataFrame)
+    try {
+      publish
+      query.processAllAvailable()
+      // NOTE(review): the query never terminates by itself, so this just waits 10 seconds; it
+      // appears to give messages that arrive after processAllAvailable() time to be written.
+      // An eventually-style poll would be faster and less flaky.
+      query.awaitTermination(10000)
+      readBackStreamingResults(sqlContext)
+    } finally {
+      query.stop()
+    }
+  }
+
+  test("basic usage") {
+
+    val sendMessage = "MQTT is a message queue."
+
+    val resultBuffer: mutable.Buffer[String] =
+      runQueryAndReadBack(mqttTestUtils.publishData("test", sendMessage))
+
+    assert(resultBuffer.size == 1)
+    assert(resultBuffer.head == sendMessage)
+  }
+
+  test("Send and receive 50 messages.") {
+
+    val sendMessage = "MQTT is a message queue."
+
+    val resultBuffer: mutable.Buffer[String] =
+      runQueryAndReadBack(mqttTestUtils.publishData("test", sendMessage, 50))
+
+    assert(resultBuffer.size == 50)
+    // Every received row must be the published payload, not just the first one.
+    assert(resultBuffer.forall(_ == sendMessage))
+  }
+
+  test("no server up") {
+    val provider = new HDFSMQTTSourceProvider
+    val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext
+    intercept[MqttException] {
+      provider.createSource(
+        sqlContext,
+        s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0",
+        Some(MQTTStreamConstants.SCHEMA_DEFAULT),
+        "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider",
+        Map("brokerUrl" -> "tcp://localhost:1881", "topic" -> "test")
+      )
+    }
+  }
+
+  test("params not provided.") {
+    val provider = new HDFSMQTTSourceProvider
+    val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext
+    intercept[IllegalArgumentException] {
+      provider.createSource(
+        sqlContext,
+        s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0",
+        Some(MQTTStreamConstants.SCHEMA_DEFAULT),
+        "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider",
+        Map()
+      )
+    }
+  }
+}