Posted to commits@bahir.apache.org by lr...@apache.org on 2018/12/15 21:39:37 UTC
[2/3] bahir git commit: [BAHIR-183] Using HDFS for saving messages for the MQTT source.
[BAHIR-183] Using HDFS for saving messages for the MQTT source.
Closes #78
Project: http://git-wip-us.apache.org/repos/asf/bahir/repo
Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/172d7096
Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/172d7096
Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/172d7096
Branch: refs/heads/master
Commit: 172d7096147cd0be70687af893a4d71380ce47bf
Parents: 63878bf
Author: wangyanlin01 <wa...@baidu.com>
Authored: Sun Dec 2 11:00:21 2018 +0800
Committer: Luciano Resende <lr...@apache.org>
Committed: Sat Dec 15 18:39:19 2018 -0300
----------------------------------------------------------------------
sql-streaming-mqtt/pom.xml | 12 +
....apache.spark.sql.sources.DataSourceRegister | 3 +-
.../sql/streaming/mqtt/CachedMQTTClient.scala | 2 +-
.../sql/streaming/mqtt/MQTTStreamSink.scala | 2 +-
.../sql/streaming/mqtt/MQTTStreamSource.scala | 2 +-
.../bahir/sql/streaming/mqtt/MQTTUtils.scala | 15 +-
.../spark/sql/mqtt/HDFSMQTTSourceProvider.scala | 64 +++
.../sql/mqtt/HdfsBasedMQTTStreamSource.scala | 400 +++++++++++++++++++
.../mqtt/HDFSBasedMQTTStreamSourceSuite.scala | 198 +++++++++
9 files changed, 689 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
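For context, a minimal sketch of how an application might consume the new source. This is illustrative only and not part of the commit: it assumes a broker at tcp://localhost:1883 and an HDFS-backed checkpoint location, and mirrors the options exercised by the test suite below (the flow-control option names come from the MQTTUtils change in this diff).

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("hdfs-mqtt-example").getOrCreate()

    // The source persists incoming messages under the checkpoint location,
    // so it should point at HDFS (or another FileSystem implementation).
    spark.sqlContext.setConf("spark.sql.streaming.checkpointLocation",
      "hdfs://namenode:8020/appCheckpoint")

    val lines = spark.readStream
      .format("org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider")
      .option("topic", "test")
      .option("clientId", "clientId")
      .option("QoS", "2")
      // New flow-control options introduced by this change:
      .option("maxBatchMessageNum", "10000")      // cap messages per batch
      .option("maxBatchMessageSize", "67108864")  // cap bytes per batch (64 MB)
      .load("tcp://localhost:1883")

    val query = lines.selectExpr("CAST(payload AS STRING)")
      .writeStream.format("csv").start("/tmp/mqtt-out")
    query.awaitTermination()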
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/pom.xml
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/pom.xml b/sql-streaming-mqtt/pom.xml
index 63497dc..05a3fff 100644
--- a/sql-streaming-mqtt/pom.xml
+++ b/sql-streaming-mqtt/pom.xml
@@ -85,6 +85,18 @@
<version>5.13.3</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-hdfs</artifactId>
+ <version>2.6.5</version>
+ <classifier>tests</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ <version>2.6.5</version>
+ <classifier>tests</classifier>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index d3899e6..1920a6b 100644
--- a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -16,4 +16,5 @@
#
org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider
-org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider
\ No newline at end of file
+org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider
+org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
index fed2601..8925e93 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala
@@ -66,7 +66,7 @@ private[mqtt] object CachedMQTTClient extends Logging {
private def createMqttClient(config: Map[String, String]):
(MqttClient, MqttClientPersistence) = {
- val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _) =
+ val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _, _, _, _) =
MQTTUtils.parseConfigParams(config)
val client = new MqttClient(brokerUrl, clientId, persistence)
val callback = new MqttCallbackExtended() {
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
index f449e57..846765c 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala
@@ -52,7 +52,7 @@ class MQTTStreamWriter (schema: StructType, parameters: DataSourceOptions)
initialize()
private def initialize(): Unit = {
- val (_, _, topic_, _, _, qos_) = MQTTUtils.parseConfigParams(
+ val (_, _, topic_, _, _, qos_, _, _, _) = MQTTUtils.parseConfigParams(
collection.immutable.HashMap() ++ parameters.asMap().asScala
)
topic = topic_
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
index 98bc60e..a40ff51 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
@@ -244,7 +244,7 @@ class MQTTStreamSourceProvider extends DataSourceV2
}
import scala.collection.JavaConverters._
- val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) =
+ val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos, _, _, _) =
MQTTUtils.parseConfigParams(collection.immutable.HashMap() ++ parameters.asMap().asScala)
new MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId,
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
index f0a6f1a..9df46bc 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala
@@ -26,8 +26,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.bahir.utils.Logging
-
-private[mqtt] object MQTTUtils extends Logging {
+object MQTTUtils extends Logging {
// Since data source configuration properties are case-insensitive,
// we have to introduce our own keys. Also, good for vendor independence.
private[mqtt] val sslParamMapping = Map(
@@ -45,8 +44,8 @@ private[mqtt] object MQTTUtils extends Logging {
"ssl.trust.manager" -> "com.ibm.ssl.trustManager"
)
- private[mqtt] def parseConfigParams(config: Map[String, String]):
- (String, String, String, MqttClientPersistence, MqttConnectOptions, Int) = {
+ def parseConfigParams(config: Map[String, String]):
+ (String, String, String, MqttClientPersistence, MqttConnectOptions, Int, Long, Long, Int) = {
def e(s: String) = new IllegalArgumentException(s)
val parameters = CaseInsensitiveMap(config)
@@ -84,6 +83,11 @@ private[mqtt] object MQTTUtils extends Logging {
val autoReconnect: Boolean = parameters.getOrElse("autoReconnect", "false").toBoolean
val maxInflight: Int = parameters.getOrElse("maxInflight", "60").toInt
+ val maxBatchMessageNum = parameters.getOrElse("maxBatchMessageNum", s"${Long.MaxValue}").toLong
+ val maxBatchMessageSize = parameters.getOrElse("maxBatchMessageSize",
+ s"${Long.MaxValue}").toLong
+ val maxRetryNumber = parameters.getOrElse("maxRetryNum", "3").toInt
+
val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions()
mqttConnectOptions.setAutomaticReconnect(autoReconnect)
mqttConnectOptions.setCleanSession(cleanSession)
@@ -105,6 +109,7 @@ private[mqtt] object MQTTUtils extends Logging {
})
mqttConnectOptions.setSSLProperties(sslProperties)
- (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos)
+ (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos,
+ maxBatchMessageNum, maxBatchMessageSize, maxRetryNumber)
}
}
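For reference, a hypothetical call showing how a caller destructures the widened result; the three trailing values are the additions. As the test suite suggests, brokerUrl and topic are the only required options; maxBatchMessageNum and maxBatchMessageSize default to Long.MaxValue and maxRetryNum to 3.

    // Hypothetical illustration; the option keys match those parsed above.
    val (brokerUrl, clientId, topic, persistence, connectOptions, qos,
        maxBatchMessageNum, maxBatchMessageSize, maxRetryNumber) =
      MQTTUtils.parseConfigParams(Map(
        "brokerUrl" -> "tcp://localhost:1883",
        "topic" -> "sensors",
        "maxBatchMessageNum" -> "10000",
        "maxBatchMessageSize" -> "67108864",
        "maxRetryNum" -> "5"))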
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala
new file mode 100644
index 0000000..f38d842
--- /dev/null
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.mqtt
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.types.StructType
+
+import org.apache.bahir.sql.streaming.mqtt.{MQTTStreamConstants, MQTTUtils}
+
+/**
+ * The provider class for creating the MQTT source.
+ * This provider throws IllegalArgumentException if the 'brokerUrl' or 'topic'
+ * parameter is not set in the options.
+ */
+class HDFSMQTTSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
+
+ override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType],
+ providerName: String, parameters: Map[String, String]): (String, StructType) = {
+ ("hdfs-mqtt", MQTTStreamConstants.SCHEMA_DEFAULT)
+ }
+
+ override def createSource(sqlContext: SQLContext, metadataPath: String,
+ schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = {
+
+ val parsedResult = MQTTUtils.parseConfigParams(parameters)
+
+ new HdfsBasedMQTTStreamSource(
+ sqlContext,
+ metadataPath,
+ parsedResult._1, // brokerUrl
+ parsedResult._2, // clientId
+ parsedResult._3, // topic
+ parsedResult._5, // mqttConnectionOptions
+ parsedResult._6, // qos
+ parsedResult._7, // maxBatchMessageNum
+ parsedResult._8, // maxBatchMessageSize
+ parsedResult._9 // maxRetryNum
+ )
+ }
+
+ override def shortName(): String = "hdfs-mqtt"
+}
+
+object HDFSMQTTSourceProvider {
+ val SEP = "##"
+}
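Because the provider registers the short name "hdfs-mqtt" (and is listed in the DataSourceRegister service file above), the fully qualified class name in the earlier sketch can be replaced by the alias. A sketch, assuming an existing SparkSession named spark:

    val df = spark.readStream
      .format("hdfs-mqtt")  // resolved via DataSourceRegister
      .option("topic", "test")
      .load("tcp://localhost:1883")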
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala
new file mode 100644
index 0000000..e6e202b
--- /dev/null
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.mqtt
+
+import java.io.IOException
+import java.sql.Timestamp
+import java.util.Calendar
+import java.util.concurrent.locks.{Lock, ReentrantLock}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path, PathFilter}
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.FileContextManager
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.bahir.sql.streaming.mqtt.{LongOffset, MQTTStreamConstants}
+
+/**
+ * A text-based MQTT stream source. It interprets the payload of each incoming message
+ * as a String, decoded with Charset.defaultCharset. Each value is associated with the
+ * timestamp at which the message arrived at the source, so a window can be operated
+ * over the incoming stream.
+ *
+ * @param sqlContext Spark provided, SqlContext.
+ * @param metadataPath meta data path
+ * @param brokerUrl url MqttClient connects to.
+ * @param topic topic MqttClient subscribes to.
+ * @param clientId the client identifier this client is associated with.
+ * Provide the same value to recover a stopped client.
+ * @param mqttConnectOptions an instance of MqttConnectOptions for this Source.
+ * @param qos the maximum quality of service to subscribe each topic at.
+ * Messages published at a lower quality of service will be received
+ * at the published QoS. Messages published at a higher quality of
+ * service will be received using the QoS specified on the subscribe.
+ * @param maxBatchNumber the maximum number of messages to process in one batch.
+ * @param maxBatchSize the maximum total message size in one batch, in bytes.
+ * @param maxRetryNumber the maximum number of consecutive failures tolerated.
+ */
+class HdfsBasedMQTTStreamSource(
+ sqlContext: SQLContext,
+ metadataPath: String,
+ brokerUrl: String,
+ clientId: String,
+ topic: String,
+ mqttConnectOptions: MqttConnectOptions,
+ qos: Int,
+ maxBatchNumber: Long = Long.MaxValue,
+ maxBatchSize: Long = Long.MaxValue,
+ maxRetryNumber: Int = 3
+) extends Source with Logging {
+
+ import HDFSMQTTSourceProvider.SEP
+
+ override def schema: StructType = MQTTStreamConstants.SCHEMA_DEFAULT
+
+ // Last batch offset file index
+ private var lastOffset: Long = -1L
+
+ // Current data file index to write messages.
+ private var currentMessageDataFileOffset: Long = 0L
+
+ // FileSystem instance for storing received messages.
+ private var fs: FileSystem = _
+ private var messageStoreOutputStream: FSDataOutputStream = _
+
+ // total message number received for current batch.
+ private var messageNumberForCurrentBatch: Int = 0
+ // total message size, in bytes, received for the current batch.
+ private var messageSizeForCurrentBatch: Int = 0
+
+ private val minBatchesToRetain = sqlContext.sparkSession.sessionState.conf.minBatchesToRetain
+
+ // number of consecutive failures; must stay below `maxRetryNumber`
+ private var consecutiveFailNum = 0
+
+ private var client: MqttClient = _
+
+ private val lock: Lock = new ReentrantLock()
+
+ private val hadoopConfig: Configuration = if (HdfsBasedMQTTStreamSource.hadoopConfig != null) {
+ logInfo("using setted hadoop configuration!")
+ HdfsBasedMQTTStreamSource.hadoopConfig
+ } else {
+ logInfo("create a new configuration.")
+ new Configuration()
+ }
+
+ private val rootCheckpointPath = {
+ val path = new Path(metadataPath).getParent.getParent.toUri.toString
+ logInfo(s"get rootCheckpointPath $path")
+ path
+ }
+
+ private val receivedDataPath = s"$rootCheckpointPath/receivedMessages"
+
+ // lazily init latest offset from offset WAL log
+ private lazy val recoveredLatestOffset = {
+ // the index of this source, parsing from metadata path
+ val currentSourceIndex = {
+ if (!metadataPath.isEmpty) {
+ metadataPath.substring(metadataPath.lastIndexOf("/") + 1).toInt
+ } else {
+ -1
+ }
+ }
+ if (currentSourceIndex >= 0) {
+ val offsetLog = new OffsetSeqLog(sqlContext.sparkSession,
+ new Path(rootCheckpointPath, "offsets").toUri.toString)
+ // get the latest offset from WAL log
+ offsetLog.getLatest() match {
+ case Some((batchId, _)) =>
+ logInfo(s"get latest batch $batchId")
+ Some(batchId)
+ case None =>
+ logInfo("no offset avaliable in offset log")
+ None
+ }
+ } else {
+ logInfo("checkpoint path is not set")
+ None
+ }
+ }
+
+ initialize()
+
+ // Roll to a new data file when a flow-control threshold is reached for the batch.
+ // Not thread safe.
+ private def startWriteNewDataFile(): Unit = {
+ if (messageStoreOutputStream != null) {
+ logInfo(s"Need to write a new data file,"
+ + s" close current data file index $currentMessageDataFileOffset")
+ messageStoreOutputStream.flush()
+ messageStoreOutputStream.hsync()
+ messageStoreOutputStream.close()
+ messageStoreOutputStream = null
+ }
+ currentMessageDataFileOffset += 1
+ messageSizeForCurrentBatch = 0
+ messageNumberForCurrentBatch = 0
+ messageStoreOutputStream = null
+ }
+
+ // not thread safe
+ private def addReceivedMessageInfo(messageNum: Int, messageSize: Int): Unit = {
+ messageSizeForCurrentBatch += messageSize
+ messageNumberForCurrentBatch += messageNum
+ }
+
+ // not thread safe
+ private def hasNewMessageForCurrentBatch(): Boolean = {
+ currentMessageDataFileOffset > lastOffset + 1 || messageNumberForCurrentBatch > 0
+ }
+
+ private def withLock[T](body: => T): T = {
+ lock.lock()
+ try body
+ finally lock.unlock()
+ }
+
+ private def initialize(): Unit = {
+
+ // recover lastOffset from WAL log
+ if (recoveredLatestOffset.nonEmpty) {
+ lastOffset = recoveredLatestOffset.get
+ logInfo(s"Recover lastOffset value ${lastOffset}")
+ }
+
+ fs = FileSystem.get(hadoopConfig)
+
+ // recover message data file offset from hdfs
+ val dataPath = new Path(receivedDataPath)
+ if (fs.exists(dataPath)) {
+ val fileManager = new FileContextManager(dataPath, hadoopConfig)
+ val dataFileIndexs = fileManager.list(dataPath, new PathFilter {
+ private def isBatchFile(path: Path) = {
+ try {
+ path.getName.toLong
+ true
+ } catch {
+ case _: NumberFormatException => false
+ }
+ }
+
+ override def accept(path: Path): Boolean = isBatchFile(path)
+ }).map(_.getPath.getName.toLong)
+ if (dataFileIndexs.nonEmpty) {
+ currentMessageDataFileOffset = dataFileIndexs.max + 1
+ assert(currentMessageDataFileOffset >= lastOffset + 1,
+ s"Recovered invalid message data file offset $currentMessageDataFileOffset,"
+ + s"do not match with lastOffset $lastOffset")
+ logInfo(s"Recovered last message data file offset: ${currentMessageDataFileOffset - 1}, "
+ + s"start from $currentMessageDataFileOffset")
+ } else {
+ logInfo("No old data file exist, start data file index from 0")
+ currentMessageDataFileOffset = 0
+ }
+ } else {
+ logInfo(s"Create data dir $receivedDataPath, start data file index from 0")
+ fs.mkdirs(dataPath)
+ currentMessageDataFileOffset = 0
+ }
+
+ client = new MqttClient(brokerUrl, clientId, new MemoryPersistence())
+
+ val callback = new MqttCallbackExtended() {
+
+ override def messageArrived(topic: String, message: MqttMessage): Unit = {
+ withLock[Unit] {
+ val messageSize = message.getPayload.size
+ // check whether the max message number or size for the current batch has been reached.
+ if (messageNumberForCurrentBatch + 1 > maxBatchNumber
+ || messageSizeForCurrentBatch + messageSize > maxBatchSize) {
+ startWriteNewDataFile()
+ }
+ // write message content to data file
+ if (messageStoreOutputStream == null) {
+ val path = new Path(s"${receivedDataPath}/${currentMessageDataFileOffset}")
+ if (fs.createNewFile(path)) {
+ logInfo(s"Create new message data file ${path.toUri.toString} success!")
+ } else {
+ throw new IOException(s"${path.toUri.toString} already exists; "
+ + s"make sure to use a unique checkpoint path for each application.")
+ }
+ messageStoreOutputStream = fs.append(path)
+ }
+
+ messageStoreOutputStream.writeBytes(s"${message.getId}${SEP}")
+ messageStoreOutputStream.writeBytes(s"${topic}${SEP}")
+ val timestamp = Calendar.getInstance().getTimeInMillis().toString
+ messageStoreOutputStream.writeBytes(s"${timestamp}${SEP}")
+ messageStoreOutputStream.write(message.getPayload())
+ messageStoreOutputStream.writeBytes("\n")
+ addReceivedMessageInfo(1, messageSize)
+ consecutiveFailNum = 0
+ logInfo(s"Message arrived, topic: $topic, message payload $message, "
+ + s"messageId: ${message.getId}, message size: ${messageSize}")
+ }
+ }
+
+ override def deliveryComplete(token: IMqttDeliveryToken): Unit = {
+ // callback for publisher, no need here.
+ }
+
+ override def connectionLost(cause: Throwable): Unit = {
+ // auto reconnection is enabled, so just add a log here.
+ withLock[Unit] {
+ consecutiveFailNum += 1
+ logWarning(s"Connection to mqtt server lost, "
+ + s"consecutive fail number $consecutiveFailNum", cause)
+ }
+ }
+
+ override def connectComplete(reconnect: Boolean, serverURI: String): Unit = {
+ logInfo(s"Connect complete $serverURI. Is it a reconnect?: $reconnect")
+ }
+ }
+ client.setCallback(callback)
+ client.connect(mqttConnectOptions)
+ client.subscribe(topic, qos)
+ }
+
+ /** Stop this source and free any resources it has allocated. */
+ override def stop(): Unit = {
+ logInfo("Stop mqtt source.")
+ client.disconnect()
+ client.close()
+ withLock[Unit] {
+ if (messageStoreOutputStream != null) {
+ messageStoreOutputStream.hflush()
+ messageStoreOutputStream.hsync()
+ messageStoreOutputStream.close()
+ messageStoreOutputStream = null
+ }
+ fs.close()
+ }
+ }
+
+ /** Returns the maximum available offset for this source. */
+ override def getOffset: Option[Offset] = {
+ withLock[Option[Offset]] {
+ assert(consecutiveFailNum < maxRetryNumber,
+ s"Write message data fail continuously for ${maxRetryNumber} times.")
+ val result = if (!hasNewMessageForCurrentBatch()) {
+ if (lastOffset == -1) {
+ // first submit and no message has arrived.
+ None
+ } else {
+ // no message has arrived for this batch.
+ Some(LongOffset(lastOffset))
+ }
+ } else {
+ // if still writing the data file for the batch about to execute, roll to a new file.
+ if (currentMessageDataFileOffset == lastOffset + 1) {
+ startWriteNewDataFile()
+ }
+ lastOffset += 1
+ Some(LongOffset(lastOffset))
+ }
+ logInfo(s"getOffset result $result")
+ result
+ }
+ }
+
+ /**
+ * Returns the data that is between the offsets (`start`, `end`].
+ * The batch returns the data in file ${checkpointPath}/receivedMessages/${end}.
+ * When `start` is defined, `end` = `start` + 1,
+ * i.e. batches are consumed one data file at a time.
+ */
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ withLock[Unit]{
+ assert(consecutiveFailNum < maxRetryNumber,
+ s"Write message data fail continuously for ${maxRetryNumber} times.")
+ }
+ logInfo(s"getBatch with start = $start, end = $end")
+ val endIndex = getOffsetValue(end)
+ if (start.nonEmpty) {
+ val startIndex = getOffsetValue(start.get)
+ assert(startIndex + 1 == endIndex,
+ s"start offset: ${startIndex} and end offset: ${endIndex} do not match")
+ }
+ logTrace(s"Create a data frame using hdfs file $receivedDataPath/$endIndex")
+ val rdd = sqlContext.sparkContext.textFile(s"$receivedDataPath/$endIndex")
+ .map{case str =>
+ // extract the message id
+ val idIndex = str.indexOf(SEP)
+ val messageId = str.substring(0, idIndex).toInt
+ // get topic
+ var subStr = str.substring(idIndex + SEP.length)
+ val topicIndex = subStr.indexOf(SEP)
+ val topic = UTF8String.fromString(subStr.substring(0, topicIndex))
+ // get timestamp
+ subStr = subStr.substring(topicIndex + SEP.length)
+ val timestampIndex = subStr.indexOf(SEP)
+ /*
+ val timestamp = Timestamp.valueOf(
+ MQTTStreamConstants.DATE_FORMAT.format(subStr.substring(0, timestampIndex).toLong))
+ */
+ val timestamp = subStr.substring(0, timestampIndex).toLong
+ // get payload
+ subStr = subStr.substring(timestampIndex + SEP.length)
+ val payload = UTF8String.fromString(subStr).getBytes
+ InternalRow(messageId, topic, payload, timestamp)
+ }
+ sqlContext.internalCreateDataFrame(rdd, MQTTStreamConstants.SCHEMA_DEFAULT, true)
+ }
+
+ /**
+ * Remove the data file for the offset.
+ *
+ * @param end the end of offset that all data has been committed.
+ */
+ override def commit(end: Offset): Unit = {
+ val offsetValue = getOffsetValue(end)
+ if (offsetValue >= minBatchesToRetain) {
+ val deleteDataFileOffset = offsetValue - minBatchesToRetain
+ try {
+ fs.delete(new Path(s"$receivedDataPath/$deleteDataFileOffset"), false)
+ logInfo(s"Delete committed offset data file $deleteDataFileOffset success!")
+ } catch {
+ case e: Exception =>
+ logWarning(s"Delete committed offset data file $deleteDataFileOffset failed. ", e)
+ }
+ }
+ }
+
+ private def getOffsetValue(offset: Offset): Long = {
+ val offsetValue = offset match {
+ case o: LongOffset => o.offset
+ case so: SerializedOffset =>
+ so.json.toLong
+ }
+ offsetValue
+ }
+}
+object HdfsBasedMQTTStreamSource {
+
+ var hadoopConfig: Configuration = _
+}
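To summarize the storage format implied by messageArrived and getBatch above: each received message becomes one line in $rootCheckpointPath/receivedMessages/$offset, with fields joined by the "##" separator. A sketch of the encoding (not shipped code; encodeRecord is a hypothetical helper for illustration):

    // Each record line: <messageId>##<topic>##<arrivalTimeMillis>##<payload>\n
    def encodeRecord(id: Int, topic: String, arrivalMillis: Long,
        payload: Array[Byte]): Array[Byte] = {
      val SEP = "##"  // HDFSMQTTSourceProvider.SEP
      s"$id$SEP$topic$SEP$arrivalMillis$SEP".getBytes ++ payload :+ '\n'.toByte
    }

getBatch then reads the file for the requested offset with sparkContext.textFile and splits each line on the separator to rebuild the (id, topic, payload, timestamp) rows.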
http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala
new file mode 100644
index 0000000..777db16
--- /dev/null
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.mqtt
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hdfs.MiniDFSCluster
+import org.apache.hadoop.security.Groups
+import org.eclipse.paho.client.mqttv3.MqttException
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+import org.apache.spark.sql._
+import org.apache.spark.sql.mqtt.{HdfsBasedMQTTStreamSource, HDFSMQTTSourceProvider}
+import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery}
+
+import org.apache.bahir.utils.FileHelper
+
+class HDFSBasedMQTTStreamSourceSuite
+ extends SparkFunSuite
+ with SharedSparkContext
+ with BeforeAndAfter {
+
+ protected var mqttTestUtils: MQTTTestUtils = _
+ protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/")
+ protected var hadoop: MiniDFSCluster = _
+
+ before {
+ tempDir.mkdirs()
+ if (!tempDir.exists()) {
+ throw new IllegalStateException("Unable to create temp directories.")
+ }
+ tempDir.deleteOnExit()
+ mqttTestUtils = new MQTTTestUtils(tempDir)
+ mqttTestUtils.setup()
+ hadoop = HDFSTestUtils.prepareHadoop()
+ }
+
+ after {
+ mqttTestUtils.teardown()
+ HDFSTestUtils.shutdownHadoop()
+ FileHelper.deleteFileQuietly(tempDir)
+ }
+
+ protected val tmpDir: String = tempDir.getAbsolutePath
+
+ protected def writeStreamResults(sqlContext: SQLContext, dataFrame: DataFrame): StreamingQuery = {
+ import sqlContext.implicits._
+ val query: StreamingQuery = dataFrame.selectExpr("CAST(payload AS STRING)").as[String]
+ .writeStream.format("csv").start(s"$tempDir/t.csv")
+ while (!query.status.isTriggerActive) {
+ Thread.sleep(20)
+ }
+ query
+ }
+
+ protected def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = {
+ import sqlContext.implicits._
+ val asList =
+ sqlContext.read
+ .csv(s"$tmpDir/t.csv").as[String]
+ .collectAsList().asScala
+ asList
+ }
+
+ protected def createStreamingDataFrame(dir: String = tmpDir): (SQLContext, DataFrame) = {
+
+ val sqlContext: SQLContext = SparkSession.builder()
+ .getOrCreate().sqlContext
+
+ sqlContext.setConf("spark.sql.streaming.checkpointLocation",
+ s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint")
+
+ val ds: DataStreamReader =
+ sqlContext.readStream.format("org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider")
+ .option("topic", "test").option("clientId", "clientId").option("connectionTimeout", "120")
+ .option("keepAlive", "1200").option("autoReconnect", "false")
+ .option("cleanSession", "true").option("QoS", "2")
+ val dataFrame = ds.load("tcp://" + mqttTestUtils.brokerUri)
+ (sqlContext, dataFrame)
+ }
+}
+
+object HDFSTestUtils {
+
+ private var hadoop: MiniDFSCluster = _
+
+ def prepareHadoop(): MiniDFSCluster = {
+ if (hadoop != null) {
+ hadoop
+ } else {
+ val baseDir = new File(System.getProperty("java.io.tmpdir") + "/hadoop").getAbsoluteFile
+ System.setProperty("HADOOP_USER_NAME", "test")
+ val conf = new Configuration
+ conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath)
+ conf.setBoolean("dfs.namenode.acls.enabled", true)
+ conf.setBoolean("dfs.permissions", true)
+ Groups.getUserToGroupsMappingService(conf)
+ val builder = new MiniDFSCluster.Builder(conf)
+ hadoop = builder.build
+ conf.set("fs.defaultFS", "hdfs://localhost:" + hadoop.getNameNodePort + "/")
+ HdfsBasedMQTTStreamSource.hadoopConfig = conf
+ hadoop
+ }
+ }
+
+ def shutdownHadoop(): Unit = {
+ if (null != hadoop) {
+ hadoop.shutdown(true)
+ }
+ hadoop = null
+ }
+}
+
+class BasicHDFSBasedMQTTSourceSuite extends HDFSBasedMQTTStreamSourceSuite {
+
+ test("basic usage") {
+
+ val sendMessage = "MQTT is a message queue."
+
+ val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame()
+
+ val query = writeStreamResults(sqlContext, dataFrame)
+ mqttTestUtils.publishData("test", sendMessage)
+ query.processAllAvailable()
+ query.awaitTermination(10000)
+
+ val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
+
+ assert(resultBuffer.size == 1)
+ assert(resultBuffer.head == sendMessage)
+ }
+
+ test("Send and receive 50 messages.") {
+
+ val sendMessage = "MQTT is a message queue."
+
+ val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame()
+
+ val q = writeStreamResults(sqlContext, dataFrame)
+
+ mqttTestUtils.publishData("test", sendMessage, 50)
+ q.processAllAvailable()
+ q.awaitTermination(10000)
+
+ val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
+
+ assert(resultBuffer.size == 50)
+ assert(resultBuffer.head == sendMessage)
+ }
+
+ test("no server up") {
+ val provider = new HDFSMQTTSourceProvider
+ val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext
+ intercept[MqttException] {
+ provider.createSource(
+ sqlContext,
+ s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0",
+ Some(MQTTStreamConstants.SCHEMA_DEFAULT),
+ "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider",
+ Map("brokerUrl" -> "tcp://localhost:1881", "topic" -> "test")
+ )
+ }
+ }
+
+ test("params not provided.") {
+ val provider = new HDFSMQTTSourceProvider
+ val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext
+ intercept[IllegalArgumentException] {
+ provider.createSource(
+ sqlContext,
+ s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0",
+ Some(MQTTStreamConstants.SCHEMA_DEFAULT),
+ "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider",
+ Map()
+ )
+ }
+ }
+}