Posted to commits@bahir.apache.org by lr...@apache.org on 2018/11/08 03:21:25 UTC

bahir git commit: [BAHIR-164][BAHIR-165] Port Mqtt sql source to datasource v2 API

Repository: bahir
Updated Branches:
  refs/heads/master 3a211a74c -> b3902bac6


[BAHIR-164][BAHIR-165] Port Mqtt sql source to datasource v2 API

Migrate the MQTT Spark Structured Streaming connector to the DataSource V2 API.

Closes #65
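
A minimal usage sketch of the ported source, assuming a local broker at `tcp://localhost:1883` and a placeholder topic name; the option names and the new binary `payload` column follow the updated README further down in this commit.

```scala
import org.apache.spark.sql.SparkSession

object MQTTReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("mqtt-v2-sketch").master("local[2]").getOrCreate()
    import spark.implicits._

    // The V2 source exposes id, topic, payload (binary) and timestamp columns,
    // so the payload is cast explicitly instead of reading a pre-decoded "value" column.
    val lines = spark.readStream
      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
      .option("topic", "sensors/temperature")   // placeholder topic
      .option("persistence", "memory")
      .load("tcp://localhost:1883")             // placeholder broker URL
      .selectExpr("CAST(payload AS STRING)").as[String]

    val query = lines.writeStream.format("console").start()
    query.awaitTermination()
  }
}
```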


Project: http://git-wip-us.apache.org/repos/asf/bahir/repo
Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/b3902bac
Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/b3902bac
Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/b3902bac

Branch: refs/heads/master
Commit: b3902bac67edc2134bcc2c755fadc5c60c8ae01c
Parents: 3a211a7
Author: Prashant Sharma <pr...@in.ibm.com>
Authored: Fri Apr 27 12:39:35 2018 +0530
Committer: Luciano Resende <lr...@apache.org>
Committed: Wed Nov 7 19:11:18 2018 -0800

----------------------------------------------------------------------
 pom.xml                                         |  18 +-
 .../streaming/akka/AkkaStreamSourceSuite.scala  |   2 +-
 sql-streaming-mqtt/README.md                    |  58 +++-
 .../streaming/mqtt/JavaMQTTStreamWordCount.java |   2 +-
 .../streaming/mqtt/MQTTStreamWordCount.scala    |   6 +-
 .../bahir/sql/streaming/mqtt/LongOffset.scala   |  54 ++++
 .../sql/streaming/mqtt/MQTTStreamSource.scala   | 284 ++++++++++++-------
 .../bahir/sql/streaming/mqtt/MessageStore.scala |  90 ++++--
 .../src/test/bin/test-BAHIR-83.sh               |  24 ++
 .../streaming/mqtt/LocalMessageStoreSuite.scala |   9 +-
 .../streaming/mqtt/MQTTStreamSourceSuite.scala  | 154 +++++-----
 .../sql/streaming/mqtt/MQTTTestUtils.scala      |  14 +-
 12 files changed, 462 insertions(+), 253 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 1edd641..8282346 100644
--- a/pom.xml
+++ b/pom.xml
@@ -77,7 +77,7 @@
   <modules>
     <module>sql-cloudant</module>
     <module>streaming-akka</module>
-    <module>sql-streaming-akka</module>
+    <!-- <module>sql-streaming-akka</module> Disabled until the akka sql module is updated to run with the datasource v2 API. -->
     <module>streaming-mqtt</module>
     <module>sql-streaming-mqtt</module>
     <module>streaming-twitter</module>
@@ -99,7 +99,7 @@
     <log4j.version>1.2.17</log4j.version>
 
     <!-- Spark version -->
-    <spark.version>2.2.2</spark.version>
+    <spark.version>2.3.0</spark.version>
 
     <!-- MQTT Client -->
     <mqtt.paho.client>1.1.0</mqtt.paho.client>
@@ -348,13 +348,13 @@
       <dependency>
         <groupId>org.scalatest</groupId>
         <artifactId>scalatest_${scala.binary.version}</artifactId>
-        <version>2.2.6</version>
+        <version>3.0.3</version>
         <scope>test</scope>
       </dependency>
       <dependency>
         <groupId>org.scalacheck</groupId>
         <artifactId>scalacheck_${scala.binary.version}</artifactId>
-        <version>1.12.5</version> <!-- 1.13.0 appears incompatible with scalatest 2.2.6 -->
+        <version>1.13.5</version>
         <scope>test</scope>
       </dependency>
 
@@ -407,7 +407,7 @@
         <plugin>
           <groupId>org.apache.maven.plugins</groupId>
           <artifactId>maven-enforcer-plugin</artifactId>
-          <version>1.4.1</version>
+          <version>3.0.0-M1</version>
           <executions>
             <execution>
               <id>enforce-versions</id>
@@ -433,6 +433,7 @@
                       -->
                       <exclude>org.jboss.netty</exclude>
                       <exclude>org.codehaus.groovy</exclude>
+                      <exclude>*:*_2.10</exclude>
                     </excludes>
                     <searchTransitive>true</searchTransitive>
                   </bannedDependencies>
@@ -482,7 +483,8 @@
         <plugin>
           <groupId>net.alchim31.maven</groupId>
           <artifactId>scala-maven-plugin</artifactId>
-          <version>3.3.1</version>
+          <!-- 3.3.1 won't work with zinc; fails to find javac from java.home -->
+          <version>3.2.2</version>
           <executions>
             <execution>
               <id>eclipse-add-source</id>
@@ -557,7 +559,7 @@
         <plugin>
           <groupId>org.apache.maven.plugins</groupId>
           <artifactId>maven-surefire-plugin</artifactId>
-          <version>2.19.1</version>
+          <version>2.20.1</version>
           <!-- Note config is repeated in scalatest config -->
           <configuration>
             <includes>
@@ -567,7 +569,7 @@
               <include>**/*Suite.java</include>
             </includes>
             <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
-            <argLine>-Xmx3g -Xss4096k -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
+            <argLine>-ea -Xmx3g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
             <environmentVariables>
               <!--
                 Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala b/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
index 5e9b86e..cdf629b 100644
--- a/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
+++ b/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
@@ -155,7 +155,7 @@ class StressTestAkkaSource extends AkkaStreamSourceSuite {
 
   // Run with -Xmx1024m
   // Default allowed payload size sent to an akka actor is 128000 bytes.
-  test("Send & Receive messages of size 128000 bytes.") {
+  ignore("Send & Receive messages of size 128000 bytes.") {
 
     val freeMemory: Long = Runtime.getRuntime.freeMemory()
 

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/README.md
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/README.md b/sql-streaming-mqtt/README.md
index 2cfbe0f..b7f0602 100644
--- a/sql-streaming-mqtt/README.md
+++ b/sql-streaming-mqtt/README.md
@@ -59,7 +59,9 @@ This source uses [Eclipse Paho Java Client](https://eclipse.org/paho/clients/jav
  * `connectionTimeout` Sets the connection timeout, a value of 0 is interpretted as wait until client connects. See `MqttConnectOptions.setConnectionTimeout` for more information.
  * `keepAlive` Same as `MqttConnectOptions.setKeepAliveInterval`.
  * `mqttVersion` Same as `MqttConnectOptions.setMqttVersion`.
-
+ * `maxInflight` Same as `MqttConnectOptions.setMaxInflight`.
+ * `autoReconnect` Same as `MqttConnectOptions.setAutomaticReconnect`.
+ 
 ### Scala API
 
 An example, for scala API to count words from incoming message stream. 
@@ -68,7 +70,7 @@ An example, for scala API to count words from incoming message stream.
     val lines = spark.readStream
       .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
       .option("topic", topic)
-      .load(brokerUrl).as[(String, Timestamp)]
+      .load(brokerUrl).selectExpr("CAST(payload AS STRING)").as[String]
 
     // Split the lines into words
     val words = lines.map(_._1).flatMap(_.split(" "))
@@ -95,7 +97,8 @@ An example, for Java API to count words from incoming message stream.
             .readStream()
             .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
             .option("topic", topic)
-            .load(brokerUrl).select("value").as(Encoders.STRING());
+            .load(brokerUrl)
+            .selectExpr("CAST(payload AS STRING)").as(Encoders.STRING());
 
     // Split the lines into words
     Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@@ -118,3 +121,52 @@ An example, for Java API to count words from incoming message stream.
 
 Please see `JavaMQTTStreamWordCount.java` for full example.
 
+## Best Practices
+
+1. Turn MQTT into a more reliable messaging service.
+
+> *MQTT is a machine-to-machine (M2M)/"Internet of Things" connectivity protocol. It was designed as an extremely lightweight publish/subscribe messaging transport.*
+
+MQTT's design serves its intended purpose well, but applications often need stronger reliability guarantees. Since MQTT is not a distributed message queue, it does not offer the highest level of reliability features. Routing the messages through Kafka adds the guarantees of a distributed message queue, and it opens up further possibilities: a single Kafka topic can be fed by several MQTT sources, and a single MQTT stream can publish to multiple Kafka topics. Kafka is a reliable and scalable message queue.
+
+2. Often the message payload is not in the default character encoding, or it contains binary data that needs a particular parser. In such cases, the payload of the Spark MQTT source should be processed with that external parser. For example:
+
+ * Scala API example:
+```scala
+    // Create DataFrame representing the stream of binary messages
+    val lines = spark.readStream
+      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
+      .option("topic", topic)
+      .load(brokerUrl).select("payload").as[Array[Byte]].map(externalParser(_))
+```
+
+ * Java API example
+```java
+        // Create DataFrame representing the stream of binary messages
+        Dataset<byte[]> lines = spark
+                .readStream()
+                .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
+                .option("topic", topic)
+                .load(brokerUrl).selectExpr("CAST(payload AS BINARY)").as(Encoders.BINARY());
+
+        // Split the lines into words
+        Dataset<String> words = lines.map(new MapFunction<byte[], String>() {
+            @Override
+            public String call(byte[] bytes) throws Exception {
+                return new String(bytes); // Plug in external parser here.
+            }
+        }, Encoders.STRING()).flatMap(new FlatMapFunction<String, String>() {
+            @Override
+            public Iterator<String> call(String x) {
+                return Arrays.asList(x.split(" ")).iterator();
+            }
+        }, Encoders.STRING());
+
+```
+
+3. What is the solution when there are a large number of varied MQTT sources, each with different schema and throughput characteristics?
+
+Generally, one would create many separate streaming pipelines to solve this problem. That either requires a very sophisticated scheduling setup or wastes resources, since it is not known in advance which stream carries more data.
+
+Such a setup is both less optimal and more cumbersome to operate, and its many moving parts incur a high maintenance overhead. As an alternative, one can set up a single-topic Kafka-Spark stream in which each message carries a unique tag identifying the MQTT stream it came from. At the processing end, the tag makes it possible to tell the messages apart and apply the right kind of decoding and processing; similarly, when storing results, the tag keeps messages from the different streams distinguishable.
+
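
A hedged sketch of the single-topic fan-in described in item 3, assuming the `spark-sql-kafka-0-10` sink is on the classpath; the broker addresses, topic names, and checkpoint path are placeholders. The source's `topic` column doubles as the per-stream tag.

```scala
import org.apache.spark.sql.SparkSession

object MqttToKafkaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("mqtt-kafka-fanin").master("local[2]").getOrCreate()

    val mqtt = spark.readStream
      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
      .option("topic", "devices/#")                 // placeholder wildcard subscription
      .load("tcp://localhost:1883")                 // placeholder broker URL

    // Key each record by its originating MQTT topic so the consumer side can
    // pick the right decoder per stream; the binary payload rides along as the value.
    val query = mqtt
      .selectExpr("topic AS key", "payload AS value")
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")
      .option("topic", "mqtt-ingest")
      .option("checkpointLocation", "/tmp/mqtt-kafka-checkpoint")
      .start()

    query.awaitTermination()
  }
}
```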

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java b/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
index 519d9a0..4e87c99 100644
--- a/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
+++ b/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
@@ -71,7 +71,7 @@ public final class JavaMQTTStreamWordCount {
                 .readStream()
                 .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
                 .option("topic", topic)
-                .load(brokerUrl).select("value").as(Encoders.STRING());
+                .load(brokerUrl).selectExpr("CAST(payload AS STRING)").as(Encoders.STRING());
 
         // Split the lines into words
         Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() {

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala b/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
index 237a8fa..ee7de22 100644
--- a/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
+++ b/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
@@ -52,11 +52,11 @@ object MQTTStreamWordCount  {
     // Create DataFrame representing the stream of input lines from connection to mqtt server
     val lines = spark.readStream
       .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
-      .option("topic", topic)
-      .load(brokerUrl).as[(String, Timestamp)]
+      .option("topic", topic).option("persistence", "memory")
+      .load(brokerUrl).selectExpr("CAST(payload AS STRING)").as[String]
 
     // Split the lines into words
-    val words = lines.map(_._1).flatMap(_.split(" "))
+    val words = lines.flatMap(_.split(" "))
 
     // Generate running word count
     val wordCounts = words.groupBy("value").count()

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala
new file mode 100644
index 0000000..345b576
--- /dev/null
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.mqtt
+
+import org.apache.spark.sql.execution.streaming.Offset
+import org.apache.spark.sql.execution.streaming.SerializedOffset
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
+
+/**
+ * A simple offset for sources that produce a single linear stream of data.
+ */
+case class LongOffset(offset: Long) extends OffsetV2 {
+
+  override val json = offset.toString
+
+  def +(increment: Long): LongOffset = new LongOffset(offset + increment)
+  def -(decrement: Long): LongOffset = new LongOffset(offset - decrement)
+}
+
+object LongOffset {
+
+  /**
+   * LongOffset factory from serialized offset.
+   *
+   * @return new LongOffset
+   */
+  def apply(offset: SerializedOffset) : LongOffset = new LongOffset(offset.json.toLong)
+
+  /**
+   * Convert generic Offset to LongOffset if possible.
+   *
+   * @return converted LongOffset
+   */
+  def convert(offset: Offset): Option[LongOffset] = offset match {
+    case lo: LongOffset => Some(lo)
+    case so: SerializedOffset => Some(LongOffset(so))
+    case _ => None
+  }
+}
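
A minimal sketch, assuming Spark 2.3.x on the classpath, of how this LongOffset round-trips through its JSON form and is recovered from a generic offset, which the MicroBatchReader further down relies on.

```scala
import org.apache.spark.sql.execution.streaming.SerializedOffset

import org.apache.bahir.sql.streaming.mqtt.LongOffset

object LongOffsetSketch {
  def main(args: Array[String]): Unit = {
    val start = LongOffset(41L)
    val end = start + 1                                  // LongOffset(42)

    // The engine persists offsets as JSON; SerializedOffset carries them back.
    val restored = LongOffset(SerializedOffset(end.json))
    assert(restored == end)

    // convert() accepts both concrete and serialized offsets.
    assert(LongOffset.convert(SerializedOffset("42")).contains(LongOffset(42L)))
  }
}
```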

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
index 1739ff3..2f75ee2 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
@@ -20,20 +20,23 @@ package org.apache.bahir.sql.streaming.mqtt
 import java.nio.charset.Charset
 import java.sql.Timestamp
 import java.text.SimpleDateFormat
-import java.util.Calendar
-import java.util.concurrent.CountDownLatch
+import java.util.{Calendar, Optional}
+import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.JavaConverters._
 import scala.collection.concurrent.TrieMap
-import scala.collection.mutable.ArrayBuffer
-import scala.util.{Failure, Success, Try}
+import scala.collection.immutable.IndexedSeq
+import scala.collection.mutable.ListBuffer
 
 import org.eclipse.paho.client.mqttv3._
 import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence}
 
-import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Source}
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
-import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
+import org.apache.spark.sql.types._
 
 import org.apache.bahir.utils.Logging
 
@@ -42,15 +45,38 @@ object MQTTStreamConstants {
 
   val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
 
-  val SCHEMA_DEFAULT = StructType(StructField("value", StringType)
-    :: StructField("timestamp", TimestampType) :: Nil)
+  val SCHEMA_DEFAULT = StructType(StructField("id", IntegerType) :: StructField("topic",
+    StringType):: StructField("payload", BinaryType) :: StructField("timestamp", TimestampType) ::
+    Nil)
 }
 
+class MQTTMessage(m: MqttMessage, val topic: String) extends Serializable {
+
+  // TODO: make it configurable.
+  val timestamp: Timestamp = Timestamp.valueOf(
+    MQTTStreamConstants.DATE_FORMAT.format(Calendar.getInstance().getTime))
+  val duplicate = m.isDuplicate
+  val retained = m.isRetained
+  val qos = m.getQos
+  val id: Int = m.getId
+
+  val payload: Array[Byte] = m.getPayload
+
+  override def toString(): String = {
+    s"""MQTTMessage.
+       |Topic: ${this.topic}
+       |MessageID: ${this.id}
+       |QoS: ${this.qos}
+       |Payload: ${this.payload}
+       |Payload as string: ${new String(this.payload, Charset.defaultCharset())}
+       |isRetained: ${this.retained}
+       |isDuplicate: ${this.duplicate}
+       |TimeStamp: ${this.timestamp}
+     """.stripMargin
+  }
+}
 /**
- * A Text based mqtt stream source, it interprets the payload of each incoming message by converting
- * the bytes to String using Charset.defaultCharset as charset. Each value is associated with a
- * timestamp of arrival of the message on the source. It can be used to operate a window on the
- * incoming stream.
+ * An MQTT stream source.
  *
  * @param brokerUrl url MqttClient connects to.
  * @param persistence an instance of MqttClientPersistence. By default it is used for storing
@@ -59,53 +85,49 @@ object MQTTStreamConstants {
  * @param topic topic MqttClient subscribes to.
  * @param clientId clientId, this client is assoicated with. Provide the same value to recover
  *                 a stopped client.
- * @param messageParser parsing logic for processing incoming messages from Mqtt Server.
- * @param sqlContext Spark provided, SqlContext.
  * @param mqttConnectOptions an instance of MqttConnectOptions for this Source.
  * @param qos the maximum quality of service to subscribe each topic at.Messages published at
  *            a lower quality of service will be received at the published QoS. Messages
  *            published at a higher quality of service will be received using the QoS specified
  *            on the subscribe.
  */
-class MQTTTextStreamSource(brokerUrl: String, persistence: MqttClientPersistence,
-    topic: String, clientId: String, messageParser: Array[Byte] => (String, Timestamp),
-    sqlContext: SQLContext, mqttConnectOptions: MqttConnectOptions, qos: Int)
-  extends Source with Logging {
+class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistence:
+    MqttClientPersistence, topic: String, clientId: String,
+    mqttConnectOptions: MqttConnectOptions, qos: Int)
+  extends MicroBatchReader with Logging {
+
+  private var startOffset: OffsetV2 = _
+  private var endOffset: OffsetV2 = _
 
-  override def schema: StructType = MQTTStreamConstants.SCHEMA_DEFAULT
+  /* Messages older than the last N will not be checked for redelivery. */
+  val backLog = options.getInt("autopruning.backlog", 500)
 
-  private val store = new LocalMessageStore(persistence, sqlContext.sparkContext.getConf)
+  private val store = new LocalMessageStore(persistence)
 
-  private val messages = new TrieMap[Int, (String, Timestamp)]
+  private val messages = new TrieMap[Long, MQTTMessage]
 
-  private val initLock = new CountDownLatch(1)
+  @GuardedBy("this")
+  private var currentOffset: LongOffset = LongOffset(-1L)
 
-  private var offset = 0
+  @GuardedBy("this")
+  private var lastOffsetCommitted: LongOffset = LongOffset(-1L)
 
   private var client: MqttClient = _
 
-  private def fetchLastProcessedOffset(): Int = {
-    Try(store.maxProcessedOffset) match {
-      case Success(x) =>
-        log.info(s"Recovering from last stored offset $x")
-        x
-      case Failure(e) => 0
-    }
-  }
+  private[mqtt] def getCurrentOffset = currentOffset
 
   initialize()
   private def initialize(): Unit = {
 
     client = new MqttClient(brokerUrl, clientId, persistence)
-
     val callback = new MqttCallbackExtended() {
 
       override def messageArrived(topic_ : String, message: MqttMessage): Unit = synchronized {
-        initLock.await() // Wait for initialization to complete.
-        val temp = offset + 1
-        messages.put(temp, messageParser(message.getPayload))
-        offset = temp
-        log.trace(s"Message arrived, $topic_ $message")
+        val mqttMessage = new MQTTMessage(message, topic_)
+        val offset = currentOffset.offset + 1L
+        messages.put(offset, mqttMessage)
+        currentOffset = LongOffset(offset)
+        log.trace(s"Message arrived, $topic_ $mqttMessage")
       }
 
       override def deliveryComplete(token: IMqttDeliveryToken): Unit = {
@@ -121,116 +143,162 @@ class MQTTTextStreamSource(brokerUrl: String, persistence: MqttClientPersistence
     }
     client.setCallback(callback)
     client.connect(mqttConnectOptions)
-    client.subscribe(topic, qos)
     // It is not possible to initialize offset without `client.connect`
-    offset = fetchLastProcessedOffset()
-    initLock.countDown() // Release.
+    client.subscribe(topic, qos)
   }
 
-  /** Stop this source and free any resources it has allocated. */
-  override def stop(): Unit = {
-    client.disconnect()
-    persistence.close()
-    client.close()
+  override def setOffsetRange(
+      start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = synchronized {
+    startOffset = start.orElse(LongOffset(-1L))
+    endOffset = end.orElse(currentOffset)
   }
 
-  /** Returns the maximum available offset for this source. */
-  override def getOffset: Option[Offset] = {
-    if (offset == 0) {
-      None
-    } else {
-      Some(LongOffset(offset))
-    }
+  override def getStartOffset(): OffsetV2 = {
+    Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set"))
+  }
+
+  override def getEndOffset(): OffsetV2 = {
+    Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set"))
   }
 
-  /**
-   * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None` then
-   * the batch should begin with the first available record. This method must always return the
-   * same data for a particular `start` and `end` pair.
-   */
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
-    val startIndex = start.getOrElse(LongOffset(0L)).asInstanceOf[LongOffset].offset.toInt
-    val endIndex = end.asInstanceOf[LongOffset].offset.toInt
-    val data: ArrayBuffer[(String, Timestamp)] = ArrayBuffer.empty
-    // Move consumed messages to persistent store.
-    (startIndex + 1 to endIndex).foreach { id =>
-      val element: (String, Timestamp) = messages.getOrElse(id, store.retrieve(id))
-      data += element
-      store.store(id, element)
-      messages.remove(id, element)
+  override def deserializeOffset(json: String): OffsetV2 = {
+    LongOffset(json.toLong)
+  }
+
+  override def readSchema(): StructType = {
+    MQTTStreamConstants.SCHEMA_DEFAULT
+  }
+
+  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+    val rawList: IndexedSeq[MQTTMessage] = synchronized {
+      val sliceStart = LongOffset.convert(startOffset).get.offset + 1
+      val sliceEnd = LongOffset.convert(endOffset).get.offset + 1
+      for (i <- sliceStart until sliceEnd) yield messages(i)
+    }
+    val spark = SparkSession.getActiveSession.get
+    val numPartitions = spark.sparkContext.defaultParallelism
+
+    val slices = Array.fill(numPartitions)(new ListBuffer[MQTTMessage])
+    rawList.zipWithIndex.foreach { case (r, idx) =>
+      slices(idx % numPartitions).append(r)
     }
-    log.trace(s"Get Batch invoked, ${data.mkString}")
-    import sqlContext.implicits._
-    data.toDF("value", "timestamp")
+
+    (0 until numPartitions).map { i =>
+      val slice = slices(i)
+      new DataReaderFactory[Row] {
+        override def createDataReader(): DataReader[Row] = new DataReader[Row] {
+          private var currentIdx = -1
+
+          override def next(): Boolean = {
+            currentIdx += 1
+            currentIdx < slice.size
+          }
+
+          override def get(): Row = {
+            Row(slice(currentIdx).id, slice(currentIdx).topic,
+              slice(currentIdx).payload, slice(currentIdx).timestamp)
+          }
+
+          override def close(): Unit = {}
+        }
+      }
+    }.toList.asJava
   }
 
-}
+  override def commit(end: OffsetV2): Unit = synchronized {
+    val newOffset = LongOffset.convert(end).getOrElse(
+      sys.error(s"MQTTStreamSource.commit() received an offset ($end) that did not " +
+        s"originate with an instance of this class")
+    )
 
-class MQTTStreamSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
+    val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
 
-  override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType],
-      providerName: String, parameters: Map[String, String]): (String, StructType) = {
-    ("mqtt", MQTTStreamConstants.SCHEMA_DEFAULT)
+    if (offsetDiff < 0) {
+      sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
+    }
+
+    (lastOffsetCommitted.offset until newOffset.offset).foreach { x =>
+      messages.remove(x + 1)
+    }
+    lastOffsetCommitted = newOffset
+  }
+
+  /** Stop this source. */
+  override def stop(): Unit = synchronized {
+    client.disconnect()
+    persistence.close()
+    client.close()
   }
 
-  override def createSource(sqlContext: SQLContext, metadataPath: String,
-      schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = {
+  override def toString: String = s"MQTTStreamSource[brokerUrl: $brokerUrl, topic: $topic" +
+    s" clientId: $clientId]"
+}
+
+class MQTTStreamSourceProvider extends DataSourceV2
+  with MicroBatchReadSupport with DataSourceRegister with Logging {
 
+  override def createMicroBatchReader(schema: Optional[StructType],
+      checkpointLocation: String, parameters: DataSourceOptions): MicroBatchReader = {
     def e(s: String) = new IllegalArgumentException(s)
+    if (schema.isPresent) {
+      throw e("The mqtt source does not support a user-specified schema.")
+    }
 
-    val brokerUrl: String = parameters.getOrElse("brokerUrl", parameters.getOrElse("path",
-      throw e("Please provide a `brokerUrl` by specifying path or .options(\"brokerUrl\",...)")))
+    val brokerUrl = parameters.get("brokerUrl").orElse(parameters.get("path").orElse(null))
 
+    if (brokerUrl == null) {
+      throw e("Please provide a broker url, with option(\"brokerUrl\", ...).")
+    }
 
-    val persistence: MqttClientPersistence = parameters.get("persistence") match {
-      case Some("memory") => new MemoryPersistence()
-      case _ => val localStorage: Option[String] = parameters.get("localStorage")
+    val persistence: MqttClientPersistence = parameters.get("persistence").orElse("") match {
+      case "memory" => new MemoryPersistence()
+      case _ => val localStorage: String = parameters.get("localStorage").orElse("")
         localStorage match {
-          case Some(x) => new MqttDefaultFilePersistence(x)
-          case None => new MqttDefaultFilePersistence()
+          case "" => new MqttDefaultFilePersistence()
+          case x => new MqttDefaultFilePersistence(x)
         }
     }
 
-    val messageParserWithTimeStamp = (x: Array[Byte]) =>
-      (new String(x, Charset.defaultCharset()), Timestamp.valueOf(
-      MQTTStreamConstants.DATE_FORMAT.format(Calendar.getInstance().getTime)))
-
     // if default is subscribe everything, it leads to getting lot unwanted system messages.
-    val topic: String = parameters.getOrElse("topic",
-      throw e("Please specify a topic, by .options(\"topic\",...)"))
+    val topic: String = parameters.get("topic").orElse(null)
+    if (topic == null) {
+      throw e("Please specify a topic, by .options(\"topic\",...)")
+    }
 
-    val clientId: String = parameters.getOrElse("clientId", {
+    val clientId: String = parameters.get("clientId").orElse {
       log.warn("If `clientId` is not set, a random value is picked up." +
-        "\nRecovering from failure is not supported in such a case.")
-      MqttClient.generateClientId()})
+        " Recovering from failure is not supported in such a case.")
+      MqttClient.generateClientId()}
+
+    val username: String = parameters.get("username").orElse(null)
+    val password: String = parameters.get("password").orElse(null)
 
-    val username: Option[String] = parameters.get("username")
-    val password: Option[String] = parameters.get("password")
-    val connectionTimeout: Int = parameters.getOrElse("connectionTimeout",
+    val connectionTimeout: Int = parameters.get("connectionTimeout").orElse(
       MqttConnectOptions.CONNECTION_TIMEOUT_DEFAULT.toString).toInt
-    val keepAlive: Int = parameters.getOrElse("keepAlive", MqttConnectOptions
+    val keepAlive: Int = parameters.get("keepAlive").orElse(MqttConnectOptions
       .KEEP_ALIVE_INTERVAL_DEFAULT.toString).toInt
-    val mqttVersion: Int = parameters.getOrElse("mqttVersion", MqttConnectOptions
+    val mqttVersion: Int = parameters.get("mqttVersion").orElse(MqttConnectOptions
       .MQTT_VERSION_DEFAULT.toString).toInt
-    val cleanSession: Boolean = parameters.getOrElse("cleanSession", "false").toBoolean
-    val qos: Int = parameters.getOrElse("QoS", "1").toInt
-
+    val cleanSession: Boolean = parameters.get("cleanSession").orElse("true").toBoolean
+    val qos: Int = parameters.get("QoS").orElse("1").toInt
+    val autoReconnect: Boolean = parameters.get("autoReconnect").orElse("false").toBoolean
+    val maxInflight: Int = parameters.get("maxInflight").orElse("60").toInt
     val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions()
-    mqttConnectOptions.setAutomaticReconnect(true)
+    mqttConnectOptions.setAutomaticReconnect(autoReconnect)
     mqttConnectOptions.setCleanSession(cleanSession)
     mqttConnectOptions.setConnectionTimeout(connectionTimeout)
     mqttConnectOptions.setKeepAliveInterval(keepAlive)
     mqttConnectOptions.setMqttVersion(mqttVersion)
+    mqttConnectOptions.setMaxInflight(maxInflight)
     (username, password) match {
-      case (Some(u: String), Some(p: String)) =>
+      case (u: String, p: String) if u != null && p != null =>
         mqttConnectOptions.setUserName(u)
         mqttConnectOptions.setPassword(p.toCharArray)
       case _ =>
     }
 
-    new MQTTTextStreamSource(brokerUrl, persistence, topic, clientId,
-      messageParserWithTimeStamp, sqlContext, mqttConnectOptions, qos)
+    new  MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId,
+      mqttConnectOptions, qos)
   }
-
   override def shortName(): String = "mqtt"
 }
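
A sketch of wiring the connection options parsed above through the DataStreamReader; the option names mirror the provider code, while the broker URL, credentials, and storage path are placeholders.

```scala
import org.apache.spark.sql.SparkSession

object MQTTOptionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("mqtt-options-sketch").master("local[2]").getOrCreate()

    val df = spark.readStream
      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
      .option("topic", "test")
      .option("clientId", "stable-client-id")  // reuse the same id to recover a stopped client
      .option("username", "user").option("password", "secret")  // placeholder credentials
      .option("QoS", "1")
      .option("cleanSession", "true")          // default flipped to true in this port
      .option("autoReconnect", "false")        // maps to MqttConnectOptions.setAutomaticReconnect
      .option("maxInflight", "60")             // maps to MqttConnectOptions.setMaxInflight
      .option("persistence", "file").option("localStorage", "/tmp/mqtt-persistence")
      .load("tcp://localhost:1883")            // or .option("brokerUrl", ...)

    df.printSchema()  // id, topic, payload, timestamp
  }
}
```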

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
index 84fd8c4..d7d2657 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
@@ -18,15 +18,11 @@
 
 package org.apache.bahir.sql.streaming.mqtt
 
-import java.nio.ByteBuffer
+import java.io._
 import java.util
 
-import scala.reflect.ClassTag
-
 import org.eclipse.paho.client.mqttv3.{MqttClientPersistence, MqttPersistable, MqttPersistenceException}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerInstance}
+import scala.util.Try
 
 import org.apache.bahir.utils.Logging
 
@@ -35,16 +31,13 @@ import org.apache.bahir.utils.Logging
 trait MessageStore {
 
   /** Store a single id and corresponding serialized message */
-  def store[T: ClassTag](id: Int, message: T): Boolean
-
-  /** Retrieve messages corresponding to certain offset range */
-  def retrieve[T: ClassTag](start: Int, end: Int): Seq[T]
+  def store[T](id: Long, message: T): Boolean
 
   /** Retrieve message corresponding to a given id. */
-  def retrieve[T: ClassTag](id: Int): T
+  def retrieve[T](id: Long): T
 
   /** Highest offset we have stored */
-  def maxProcessedOffset: Int
+  def maxProcessedOffset: Long
 
 }
 
@@ -63,6 +56,52 @@ private[mqtt] class MqttPersistableData(bytes: Array[Byte]) extends MqttPersista
   override def getPayloadLength: Int = 0
 }
 
+trait Serializer {
+
+  def deserialize[T](x: Array[Byte]): T
+
+  def serialize[T](x: T): Array[Byte]
+}
+
+class JavaSerializer extends Serializer with Logging {
+
+  override def deserialize[T](x: Array[Byte]): T = {
+    val bis = new ByteArrayInputStream(x)
+    val in = new ObjectInputStream(bis)
+    val obj = if (in != null) {
+      val o = in.readObject()
+      Try(in.close()).recover { case t: Throwable => log.warn("failed to close stream", t) }
+      o
+    } else {
+      null
+    }
+    obj.asInstanceOf[T]
+  }
+
+  override def serialize[T](x: T): Array[Byte] = {
+    val bos = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(bos)
+    out.writeObject(x)
+    out.flush()
+    if (bos != null) {
+      val bytes: Array[Byte] = bos.toByteArray
+      Try(bos.close()).recover { case t: Throwable => log.warn("failed to close stream", t) }
+      bytes
+    } else {
+      null
+    }
+  }
+}
+
+object JavaSerializer {
+
+  private lazy val instance = new JavaSerializer()
+
+  def getInstance(): JavaSerializer = instance
+
+}
+
+
 /**
  * A message store to persist messages received. This is not intended to be thread safe.
  * It uses `MqttDefaultFilePersistence` for storing messages on disk locally on the client.
@@ -70,44 +109,35 @@ private[mqtt] class MqttPersistableData(bytes: Array[Byte]) extends MqttPersista
 private[mqtt] class LocalMessageStore(val persistentStore: MqttClientPersistence,
     val serializer: Serializer) extends MessageStore with Logging {
 
-  val classLoader = Thread.currentThread.getContextClassLoader
-
-  def this(persistentStore: MqttClientPersistence, conf: SparkConf) =
-    this(persistentStore, new JavaSerializer(conf))
+  def this(persistentStore: MqttClientPersistence) =
+    this(persistentStore, JavaSerializer.getInstance())
 
-  val serializerInstance: SerializerInstance = serializer.newInstance()
-
-  private def get(id: Int) = {
+  private def get(id: Long) = {
     persistentStore.get(id.toString).getHeaderBytes
   }
 
   import scala.collection.JavaConverters._
 
-  def maxProcessedOffset: Int = {
+  def maxProcessedOffset: Long = {
     val keys: util.Enumeration[_] = persistentStore.keys()
     keys.asScala.map(x => x.toString.toInt).max
   }
 
   /** Store a single id and corresponding serialized message */
-  override def store[T: ClassTag](id: Int, message: T): Boolean = {
-    val bytes: Array[Byte] = serializerInstance.serialize(message).array()
+  override def store[T](id: Long, message: T): Boolean = {
+    val bytes: Array[Byte] = serializer.serialize(message)
     try {
       persistentStore.put(id.toString, new MqttPersistableData(bytes))
       true
     } catch {
       case e: MqttPersistenceException => log.warn(s"Failed to store message Id: $id", e)
-      false
+        false
     }
   }
 
-  /** Retrieve messages corresponding to certain offset range */
-  override def retrieve[T: ClassTag](start: Int, end: Int): Seq[T] = {
-    (start until end).map(x => retrieve(x))
-  }
-
   /** Retrieve message corresponding to a given id. */
-  override def retrieve[T: ClassTag](id: Int): T = {
-    serializerInstance.deserialize(ByteBuffer.wrap(get(id)), classLoader)
+  override def retrieve[T](id: Long): T = {
+    serializer.deserialize(get(id))
   }
 
 }
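
A small round-trip sketch of the standalone JavaSerializer introduced above (it replaces the Spark serializer the message store used before); the test data mirrors LocalMessageStoreSuite.

```scala
import org.apache.bahir.sql.streaming.mqtt.JavaSerializer

object SerializerSketch {
  def main(args: Array[String]): Unit = {
    val serializer = JavaSerializer.getInstance()

    val original = Seq(1, 2, 3, 4, 5, 6)
    val bytes: Array[Byte] = serializer.serialize(original)
    val restored: Seq[Int] = serializer.deserialize[Seq[Int]](bytes)

    assert(restored == original)
  }
}
```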

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh b/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh
new file mode 100755
index 0000000..659dd8c
--- /dev/null
+++ b/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+set -o pipefail
+
+for i in `seq 100` ; do
+  mvn scalatest:test -pl sql-streaming-mqtt -q -Dsuites='*.BasicMQTTSourceSuite' | \
+    grep -q "TEST FAILED" && echo "$i: failed"
+done

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
index 9c678cb..0a2a079 100644
--- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
@@ -22,8 +22,7 @@ import java.io.File
 import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.SparkFunSuite
 
 import org.apache.bahir.utils.BahirUtils
 
@@ -31,9 +30,9 @@ import org.apache.bahir.utils.BahirUtils
 class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter {
 
   private val testData = Seq(1, 2, 3, 4, 5, 6)
-  private val javaSerializer: JavaSerializer = new JavaSerializer(new SparkConf())
+  private val javaSerializer: JavaSerializer = new JavaSerializer()
 
-  private val serializerInstance = javaSerializer.newInstance()
+  private val serializerInstance = javaSerializer
   private val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test2/")
   private val persistence: MqttDefaultFilePersistence =
     new MqttDefaultFilePersistence(tempDir.getAbsolutePath)
@@ -68,7 +67,7 @@ class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter {
   test("Max offset stored") {
     store.store(1, testData)
     store.store(10, testData)
-    val offset: Int = store.maxProcessedOffset
+    val offset = store.maxProcessedOffset
     assert(offset == 10)
   }
 

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
index 38971a0..2ce72da 100644
--- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
@@ -18,31 +18,34 @@
 package org.apache.bahir.sql.streaming.mqtt
 
 import java.io.File
-import java.sql.Timestamp
+import java.util.Optional
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.concurrent.Future
 
 import org.eclipse.paho.client.mqttv3.MqttException
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
-import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.execution.streaming.LongOffset
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery}
 
 import org.apache.bahir.utils.BahirUtils
 
-
 class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter {
 
   protected var mqttTestUtils: MQTTTestUtils = _
   protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/")
 
   before {
+    tempDir.mkdirs()
+    if (!tempDir.exists()) {
+      throw new IllegalStateException("Unable to create temp directories.")
+    }
+    tempDir.deleteOnExit()
     mqttTestUtils = new MQTTTestUtils(tempDir)
     mqttTestUtils.setup()
-    tempDir.mkdirs()
   }
 
   after {
@@ -52,16 +55,44 @@ class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with B
 
   protected val tmpDir: String = tempDir.getAbsolutePath
 
-  protected def createStreamingDataframe(dir: String = tmpDir): (SQLContext, DataFrame) = {
+  protected def writeStreamResults(sqlContext: SQLContext, dataFrame: DataFrame): StreamingQuery = {
+    import sqlContext.implicits._
+    val query: StreamingQuery = dataFrame.selectExpr("CAST(payload AS STRING)").as[String]
+      .writeStream.format("parquet").start(s"$tmpDir/t.parquet")
+    while (!query.status.isTriggerActive) {
+      Thread.sleep(20)
+    }
+    query
+  }
+
+  protected def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = {
+    import sqlContext.implicits._
+    val asList =
+      sqlContext.read
+        .parquet(s"$tmpDir/t.parquet").as[String]
+        .collectAsList().asScala
+    asList
+  }
+
+  protected def createStreamingDataframe(dir: String = tmpDir,
+      filePersistence: Boolean = false): (SQLContext, DataFrame) = {
 
     val sqlContext: SQLContext = new SQLContext(sc)
 
     sqlContext.setConf("spark.sql.streaming.checkpointLocation", tmpDir)
 
-    val dataFrame: DataFrame =
+    val ds: DataStreamReader =
       sqlContext.readStream.format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
-        .option("topic", "test").option("localStorage", dir).option("clientId", "clientId")
-        .option("QoS", "2").load("tcp://" + mqttTestUtils.brokerUri)
+        .option("topic", "test").option("clientId", "clientId").option("connectionTimeout", "120")
+        .option("keepAlive", "1200").option("maxInflight", "120").option("autoReconnect", "false")
+        .option("cleanSession", "true").option("QoS", "2")
+
+    val dataFrame = if (!filePersistence) {
+      ds.option("persistence", "memory").load("tcp://" + mqttTestUtils.brokerUri)
+    } else {
+      ds.option("persistence", "file").option("localStorage", tmpDir)
+        .load("tcp://" + mqttTestUtils.brokerUri)
+    }
     (sqlContext, dataFrame)
   }
 
@@ -69,31 +100,16 @@ class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with B
 
 class BasicMQTTSourceSuite extends MQTTStreamSourceSuite {
 
-  private def writeStreamResults(sqlContext: SQLContext,
-      dataFrame: DataFrame, waitDuration: Long): Boolean = {
-    import sqlContext.implicits._
-    dataFrame.as[(String, Timestamp)].writeStream.format("parquet").start(s"$tmpDir/t.parquet")
-      .awaitTermination(waitDuration)
-  }
-
-  private def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = {
-    import sqlContext.implicits._
-    val asList =
-      sqlContext.read.schema(MQTTStreamConstants.SCHEMA_DEFAULT)
-        .parquet(s"$tmpDir/t.parquet").as[(String, Timestamp)].map(_._1)
-        .collectAsList().asScala
-    asList
-  }
-
   test("basic usage") {
 
     val sendMessage = "MQTT is a message queue."
 
-    mqttTestUtils.publishData("test", sendMessage)
-
     val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
 
-    writeStreamResults(sqlContext, dataFrame, 5000)
+    val query = writeStreamResults(sqlContext, dataFrame)
+    mqttTestUtils.publishData("test", sendMessage)
+    query.processAllAvailable()
+    query.awaitTermination(10000)
 
     val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
 
@@ -101,88 +117,58 @@ class BasicMQTTSourceSuite extends MQTTStreamSourceSuite {
     assert(resultBuffer.head == sendMessage)
   }
 
-  // TODO: reinstate this test after fixing BAHIR-83
-  ignore("Send and receive 100 messages.") {
+  test("Send and receive 50 messages.") {
 
     val sendMessage = "MQTT is a message queue."
 
-    import scala.concurrent.ExecutionContext.Implicits.global
-
     val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
 
-    Future {
-      Thread.sleep(2000)
-      mqttTestUtils.publishData("test", sendMessage, 100)
-    }
+    val q = writeStreamResults(sqlContext, dataFrame)
 
-    writeStreamResults(sqlContext, dataFrame, 10000)
+    mqttTestUtils.publishData("test", sendMessage, 50)
+    q.processAllAvailable()
+    q.awaitTermination(10000)
 
     val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
 
-    assert(resultBuffer.size == 100)
+    assert(resultBuffer.size == 50)
     assert(resultBuffer.head == sendMessage)
   }
 
   test("no server up") {
     val provider = new MQTTStreamSourceProvider
     val sqlContext: SQLContext = new SQLContext(sc)
-    val parameters = Map("brokerUrl" -> "tcp://localhost:1883", "topic" -> "test",
-      "localStorage" -> tmpDir)
+    val parameters = new DataSourceOptions(Map("brokerUrl" ->
+      "tcp://localhost:1881", "topic" -> "test", "localStorage" -> tmpDir).asJava)
     intercept[MqttException] {
-      provider.createSource(sqlContext, "", None, "", parameters)
+      provider.createMicroBatchReader(Optional.empty(), tempDir.toString, parameters)
     }
   }
 
   test("params not provided.") {
     val provider = new MQTTStreamSourceProvider
-    val sqlContext: SQLContext = new SQLContext(sc)
-    val parameters = Map("brokerUrl" -> mqttTestUtils.brokerUri,
-      "localStorage" -> tmpDir)
+    val parameters = new DataSourceOptions(Map("brokerUrl" -> mqttTestUtils.brokerUri,
+      "localStorage" -> tmpDir).asJava)
     intercept[IllegalArgumentException] {
-      provider.createSource(sqlContext, "", None, "", parameters)
+      provider.createMicroBatchReader(Optional.empty(), tempDir.toString, parameters)
     }
     intercept[IllegalArgumentException] {
-      provider.createSource(sqlContext, "", None, "", Map())
+      provider.createMicroBatchReader(Optional.empty(), tempDir.toString, DataSourceOptions.empty())
     }
   }
 
-  // TODO: reinstate this test after fixing BAHIR-83
-  ignore("Recovering offset from the last processed offset.") {
-    val sendMessage = "MQTT is a message queue."
-
-    import scala.concurrent.ExecutionContext.Implicits.global
-
-    val (sqlContext: SQLContext, dataFrame: DataFrame) =
-      createStreamingDataframe()
-
-    Future {
-      Thread.sleep(2000)
-      mqttTestUtils.publishData("test", sendMessage, 100)
-    }
-
-    writeStreamResults(sqlContext, dataFrame, 10000)
-    // On restarting the source with same params, it should begin from the offset - the
-    // previously running stream left off.
-    val provider = new MQTTStreamSourceProvider
-    val parameters = Map("brokerUrl" -> ("tcp://" + mqttTestUtils.brokerUri), "topic" -> "test",
-      "localStorage" -> tmpDir, "clientId" -> "clientId", "QoS" -> "2")
-    val offset: Long = provider.createSource(sqlContext, "", None, "", parameters)
-      .getOffset.get.asInstanceOf[LongOffset].offset
-    assert(offset == 100L)
-  }
-
 }
 
 class StressTestMQTTSource extends MQTTStreamSourceSuite {
 
   // Run with -Xmx1024m
-  ignore("Send and receive messages of size 250MB.") {
+  test("Send and receive messages of size 100MB.") {
 
     val freeMemory: Long = Runtime.getRuntime.freeMemory()
 
     log.info(s"Available memory before test run is ${freeMemory / (1024 * 1024)}MB.")
 
-    val noOfMsgs = (250 * 1024 * 1024) / (500 * 1024) // 512
+    val noOfMsgs: Int = (100 * 1024 * 1024) / (500 * 1024) // 204
 
     val messageBuilder = new StringBuilder()
     for (i <- 0 until (500 * 1024)) yield messageBuilder.append(((i % 26) + 65).toChar)
@@ -190,22 +176,14 @@ class StressTestMQTTSource extends MQTTStreamSourceSuite {
 
     val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
 
-    import scala.concurrent.ExecutionContext.Implicits.global
-    Future {
-      Thread.sleep(2000)
-      mqttTestUtils.publishData("test", sendMessage, noOfMsgs.toInt)
-    }
-
-    import sqlContext.implicits._
-
-    dataFrame.as[(String, Timestamp)].writeStream
-      .format("parquet")
-      .start(s"$tmpDir/t.parquet")
-      .awaitTermination(25000)
+    val query = writeStreamResults(sqlContext, dataFrame)
+    mqttTestUtils.publishData("test", sendMessage, noOfMsgs )
+    query.processAllAvailable()
+    query.awaitTermination(25000)
 
     val messageCount =
-      sqlContext.read.schema(MQTTStreamConstants.SCHEMA_DEFAULT)
-        .parquet(s"$tmpDir/t.parquet").as[(String, Timestamp)].map(_._1)
+      sqlContext.read
+        .parquet(s"$tmpDir/t.parquet")
         .count()
     assert(messageCount == noOfMsgs)
   }

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
index 9c7399f..817ec9a 100644
--- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
@@ -22,15 +22,14 @@ import java.net.{ServerSocket, URI}
 
 import org.apache.activemq.broker.{BrokerService, TransportConnector}
 import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence}
 
 import org.apache.bahir.utils.Logging
 
 
 class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging {
 
-  private val persistenceDir = tempDir.getAbsolutePath
-  private val brokerHost = "localhost"
+  private val brokerHost = "127.0.0.1"
   private val brokerPort: Int = if (port == 0) findFreePort() else port
 
   private var broker: BrokerService = _
@@ -60,18 +59,21 @@ class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging {
   def teardown(): Unit = {
     if (broker != null) {
       broker.stop()
-      broker = null
     }
     if (connector != null) {
       connector.stop()
       connector = null
     }
+    while (!broker.isStopped) {
+      Thread.sleep(50)
+    }
+    broker = null
   }
 
   def publishData(topic: String, data: String, N: Int = 1): Unit = {
     var client: MqttClient = null
     try {
-      val persistence = new MqttDefaultFilePersistence(persistenceDir)
+      val persistence = new MemoryPersistence()
       client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
       client.connect()
       if (client.isConnected) {
@@ -81,7 +83,7 @@ class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging {
             Thread.sleep(20)
             val message = new MqttMessage(data.getBytes())
             message.setQos(2)
-            message.setRetained(true)
+            // message.setId(i) setting id has no effect.
             msgTopic.publish(message)
           } catch {
             case e: MqttException =>