[BAHIR-183] Use HDFS for saving messages for the MQTT source. Closes #78
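This adds an HDFS-backed MQTT source ("hdfs-mqtt") that persists received messages under the streaming checkpoint directory (one data file per batch under receivedMessages/) and serves each micro-batch from those files. A minimal usage sketch, assuming a SparkSession named spark; the broker URL, topic, and option values below are placeholders, while the option keys are the ones parsed by MQTTUtils.parseConfigParams in this change:

    // Checkpointing should point at HDFS: received messages are stored
    // under <checkpointLocation>/receivedMessages/<batchOffset>.
    spark.conf.set("spark.sql.streaming.checkpointLocation", "hdfs://localhost:9000/checkpoint")

    val lines = spark.readStream
      .format("org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider") // short name: "hdfs-mqtt"
      .option("topic", "test")                  // MQTT topic to subscribe to
      .option("clientId", "client-0")           // reuse the same id to recover a stopped client
      .option("QoS", "2")
      .option("maxBatchMessageNum", "10000")    // cap on messages per batch (default: Long.MaxValue)
      .option("maxBatchMessageSize", "1048576") // cap on payload bytes per batch (default: Long.MaxValue)
      .option("maxRetryNum", "3")               // tolerated consecutive failures (default: 3)
      .load("tcp://localhost:1883")             // brokerUrl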
Project: http://git-wip-us.apache.org/repos/asf/bahir/repo Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/172d7096 Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/172d7096 Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/172d7096 Branch: refs/heads/master Commit: 172d7096147cd0be70687af893a4d71380ce47bf Parents: 63878bf Author: wangyanlin01 <[email protected]> Authored: Sun Dec 2 11:00:21 2018 +0800 Committer: Luciano Resende <[email protected]> Committed: Sat Dec 15 18:39:19 2018 -0300 ---------------------------------------------------------------------- sql-streaming-mqtt/pom.xml | 12 + ....apache.spark.sql.sources.DataSourceRegister | 3 +- .../sql/streaming/mqtt/CachedMQTTClient.scala | 2 +- .../sql/streaming/mqtt/MQTTStreamSink.scala | 2 +- .../sql/streaming/mqtt/MQTTStreamSource.scala | 2 +- .../bahir/sql/streaming/mqtt/MQTTUtils.scala | 15 +- .../spark/sql/mqtt/HDFSMQTTSourceProvider.scala | 64 +++ .../sql/mqtt/HdfsBasedMQTTStreamSource.scala | 400 +++++++++++++++++++ .../mqtt/HDFSBasedMQTTStreamSourceSuite.scala | 198 +++++++++ 9 files changed, 689 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/pom.xml ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/pom.xml b/sql-streaming-mqtt/pom.xml index 63497dc..05a3fff 100644 --- a/sql-streaming-mqtt/pom.xml +++ b/sql-streaming-mqtt/pom.xml @@ -85,6 +85,18 @@ <version>5.13.3</version> <scope>test</scope> </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-hdfs</artifactId> + <version>2.6.5</version> + <classifier>tests</classifier> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-common</artifactId> + <version>2.6.5</version> + <classifier>tests</classifier> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index d3899e6..1920a6b 100644 --- a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -16,4 +16,5 @@ # org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider -org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider \ No newline at end of file +org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider +org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider \ No newline at end of file http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala index fed2601..8925e93 100644 --- 
a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala @@ -66,7 +66,7 @@ private[mqtt] object CachedMQTTClient extends Logging { private def createMqttClient(config: Map[String, String]): (MqttClient, MqttClientPersistence) = { - val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _) = + val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _, _, _, _) = MQTTUtils.parseConfigParams(config) val client = new MqttClient(brokerUrl, clientId, persistence) val callback = new MqttCallbackExtended() { http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala index f449e57..846765c 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala @@ -52,7 +52,7 @@ class MQTTStreamWriter (schema: StructType, parameters: DataSourceOptions) initialize() private def initialize(): Unit = { - val (_, _, topic_, _, _, qos_) = MQTTUtils.parseConfigParams( + val (_, _, topic_, _, _, qos_, _, _, _) = MQTTUtils.parseConfigParams( collection.immutable.HashMap() ++ parameters.asMap().asScala ) topic = topic_ http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala index 98bc60e..a40ff51 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala @@ -244,7 +244,7 @@ class MQTTStreamSourceProvider extends DataSourceV2 } import scala.collection.JavaConverters._ - val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) = + val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos, _, _, _) = MQTTUtils.parseConfigParams(collection.immutable.HashMap() ++ parameters.asMap().asScala) new MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId, http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala index f0a6f1a..9df46bc 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala @@ -26,8 +26,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.bahir.utils.Logging - -private[mqtt] object MQTTUtils extends Logging { +object MQTTUtils 
extends Logging { // Since data source configuration properties are case-insensitive, // we have to introduce our own keys. Also, good for vendor independence. private[mqtt] val sslParamMapping = Map( @@ -45,8 +44,8 @@ private[mqtt] object MQTTUtils extends Logging { "ssl.trust.manager" -> "com.ibm.ssl.trustManager" ) - private[mqtt] def parseConfigParams(config: Map[String, String]): - (String, String, String, MqttClientPersistence, MqttConnectOptions, Int) = { + def parseConfigParams(config: Map[String, String]): + (String, String, String, MqttClientPersistence, MqttConnectOptions, Int, Long, Long, Int) = { def e(s: String) = new IllegalArgumentException(s) val parameters = CaseInsensitiveMap(config) @@ -84,6 +83,11 @@ private[mqtt] object MQTTUtils extends Logging { val autoReconnect: Boolean = parameters.getOrElse("autoReconnect", "false").toBoolean val maxInflight: Int = parameters.getOrElse("maxInflight", "60").toInt + val maxBatchMessageNum = parameters.getOrElse("maxBatchMessageNum", s"${Long.MaxValue}").toLong + val maxBatchMessageSize = parameters.getOrElse("maxBatchMessageSize", + s"${Long.MaxValue}").toLong + val maxRetryNumber = parameters.getOrElse("maxRetryNum", "3").toInt + val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions() mqttConnectOptions.setAutomaticReconnect(autoReconnect) mqttConnectOptions.setCleanSession(cleanSession) @@ -105,6 +109,7 @@ private[mqtt] object MQTTUtils extends Logging { }) mqttConnectOptions.setSSLProperties(sslProperties) - (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) + (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos, + maxBatchMessageNum, maxBatchMessageSize, maxRetryNumber) } } http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala new file mode 100644 index 0000000..f38d842 --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.mqtt + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types.StructType + +import org.apache.bahir.sql.streaming.mqtt.{MQTTStreamConstants, MQTTUtils} + +/** + * The provider class for creating MQTT source. 
+ * This provider throws IllegalArgumentException if 'brokerUrl' or 'topic' parameter + * is not set in options. + */ +class HDFSMQTTSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { + + override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType], + providerName: String, parameters: Map[String, String]): (String, StructType) = { + ("hdfs-mqtt", MQTTStreamConstants.SCHEMA_DEFAULT) + } + + override def createSource(sqlContext: SQLContext, metadataPath: String, + schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { + + val parsedResult = MQTTUtils.parseConfigParams(parameters) + + new HdfsBasedMQTTStreamSource( + sqlContext, + metadataPath, + parsedResult._1, // brokerUrl + parsedResult._2, // clientId + parsedResult._3, // topic + parsedResult._5, // mqttConnectionOptions + parsedResult._6, // qos + parsedResult._7, // maxBatchMessageNum + parsedResult._8, // maxBatchMessageSize + parsedResult._9 // maxRetryNum + ) + } + + override def shortName(): String = "hdfs-mqtt" +} + +object HDFSMQTTSourceProvider { + val SEP = "##" +} http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala new file mode 100644 index 0000000..e6e202b --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.mqtt + +import java.io.IOException +import java.sql.Timestamp +import java.util.Calendar +import java.util.concurrent.locks.{Lock, ReentrantLock} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path, PathFilter} +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.FileContextManager +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.bahir.sql.streaming.mqtt.{LongOffset, MQTTStreamConstants} + +/** + * A text-based MQTT stream source. It interprets the payload of each incoming message by converting + * the bytes to String using Charset.defaultCharset as the charset. Each value is associated with a + * timestamp of arrival of the message at the source. It can be used to operate a window on the + * incoming stream. + * + * @param sqlContext the Spark-provided SQLContext. + * @param metadataPath the metadata path. + * @param brokerUrl the URL the MqttClient connects to. + * @param topic the topic the MqttClient subscribes to. + * @param clientId the clientId this client is associated with. + * Provide the same value to recover a stopped client. + * @param mqttConnectOptions an instance of MqttConnectOptions for this Source. + * @param qos the maximum quality of service to subscribe each topic at. + * Messages published at a lower quality of service will be received + * at the published QoS. Messages published at a higher quality of + * service will be received using the QoS specified on the subscribe. + * @param maxBatchNumber the maximum number of messages to process in one batch. + * @param maxBatchSize the maximum total message size in one batch, in bytes. + * @param maxRetryNumber the number of consecutive failures tolerated before the source aborts. + */ +class HdfsBasedMQTTStreamSource( + sqlContext: SQLContext, + metadataPath: String, + brokerUrl: String, + clientId: String, + topic: String, + mqttConnectOptions: MqttConnectOptions, + qos: Int, + maxBatchNumber: Long = Long.MaxValue, + maxBatchSize: Long = Long.MaxValue, + maxRetryNumber: Int = 3 +) extends Source with Logging { + + import HDFSMQTTSourceProvider.SEP + + override def schema: StructType = MQTTStreamConstants.SCHEMA_DEFAULT + + // Last batch offset file index + private var lastOffset: Long = -1L + + // Current data file index to write messages. + private var currentMessageDataFileOffset: Long = 0L + + // FileSystem instance for storing received messages. + private var fs: FileSystem = _ + private var messageStoreOutputStream: FSDataOutputStream = _ + + // total message number received for current batch.
+ private var messageNumberForCurrentBatch: Int = 0 + // total message size received for the current batch. + private var messageSizeForCurrentBatch: Int = 0 + + private val minBatchesToRetain = sqlContext.sparkSession.sessionState.conf.minBatchesToRetain + + // the number of consecutive failures; it cannot exceed `maxRetryNumber` + private var consecutiveFailNum = 0 + + private var client: MqttClient = _ + + private val lock: Lock = new ReentrantLock() + + private val hadoopConfig: Configuration = if (HdfsBasedMQTTStreamSource.hadoopConfig != null) { + logInfo("using the provided hadoop configuration!") + HdfsBasedMQTTStreamSource.hadoopConfig + } else { + logInfo("creating a new configuration.") + new Configuration() + } + + private val rootCheckpointPath = { + val path = new Path(metadataPath).getParent.getParent.toUri.toString + logInfo(s"get rootCheckpointPath $path") + path + } + + private val receivedDataPath = s"$rootCheckpointPath/receivedMessages" + + // lazily init latest offset from offset WAL log + private lazy val recoveredLatestOffset = { + // the index of this source, parsed from the metadata path + val currentSourceIndex = { + if (!metadataPath.isEmpty) { + metadataPath.substring(metadataPath.lastIndexOf("/") + 1).toInt + } else { + -1 + } + } + if (currentSourceIndex >= 0) { + val offsetLog = new OffsetSeqLog(sqlContext.sparkSession, + new Path(rootCheckpointPath, "offsets").toUri.toString) + // get the latest offset from WAL log + offsetLog.getLatest() match { + case Some((batchId, _)) => + logInfo(s"get latest batch $batchId") + Some(batchId) + case None => + logInfo("no offset available in offset log") + None + } + } else { + logInfo("checkpoint path is not set") + None + } + } + + initialize() + + // Change data file if the flow control threshold for one batch is reached. + // Not thread safe.
+ private def startWriteNewDataFile(): Unit = { + if (messageStoreOutputStream != null) { + logInfo(s"Need to write a new data file," + + s" close current data file index $currentMessageDataFileOffset") + messageStoreOutputStream.flush() + messageStoreOutputStream.hsync() + messageStoreOutputStream.close() + messageStoreOutputStream = null + } + currentMessageDataFileOffset += 1 + messageSizeForCurrentBatch = 0 + messageNumberForCurrentBatch = 0 + messageStoreOutputStream = null + } + + // not thread safe + private def addReceivedMessageInfo(messageNum: Int, messageSize: Int): Unit = { + messageSizeForCurrentBatch += messageSize + messageNumberForCurrentBatch += messageNum + } + + // not thread safe + private def hasNewMessageForCurrentBatch(): Boolean = { + currentMessageDataFileOffset > lastOffset + 1 || messageNumberForCurrentBatch > 0 + } + + private def withLock[T](body: => T): T = { + lock.lock() + try body + finally lock.unlock() + } + + private def initialize(): Unit = { + + // recover lastOffset from WAL log + if (recoveredLatestOffset.nonEmpty) { + lastOffset = recoveredLatestOffset.get + logInfo(s"Recovered lastOffset value ${lastOffset}") + } + + fs = FileSystem.get(hadoopConfig) + + // recover message data file offset from hdfs + val dataPath = new Path(receivedDataPath) + if (fs.exists(dataPath)) { + val fileManager = new FileContextManager(dataPath, hadoopConfig) + val dataFileIndexs = fileManager.list(dataPath, new PathFilter { + private def isBatchFile(path: Path) = { + try { + path.getName.toLong + true + } catch { + case _: NumberFormatException => false + } + } + + override def accept(path: Path): Boolean = isBatchFile(path) + }).map(_.getPath.getName.toLong) + if (dataFileIndexs.nonEmpty) { + currentMessageDataFileOffset = dataFileIndexs.max + 1 + assert(currentMessageDataFileOffset >= lastOffset + 1, + s"Recovered invalid message data file offset $currentMessageDataFileOffset, " + + s"which does not match lastOffset $lastOffset") + logInfo(s"Recovered last message data file offset: ${currentMessageDataFileOffset - 1}, " + + s"start from $currentMessageDataFileOffset") + } else { + logInfo("No old data files exist, start data file index from 0") + currentMessageDataFileOffset = 0 + } + } else { + logInfo(s"Create data dir $receivedDataPath, start data file index from 0") + fs.mkdirs(dataPath) + currentMessageDataFileOffset = 0 + } + + client = new MqttClient(brokerUrl, clientId, new MemoryPersistence()) + + val callback = new MqttCallbackExtended() { + + override def messageArrived(topic: String, message: MqttMessage): Unit = { + withLock[Unit] { + val messageSize = message.getPayload.size + // check if we have reached the max number or max size for the current batch.
+ if (messageNumberForCurrentBatch + 1 > maxBatchNumber + || messageSizeForCurrentBatch + messageSize > maxBatchSize) { + startWriteNewDataFile() + } + // write message content to data file + if (messageStoreOutputStream == null) { + val path = new Path(s"${receivedDataPath}/${currentMessageDataFileOffset}") + if (fs.createNewFile(path)) { + logInfo(s"Created new message data file ${path.toUri.toString} successfully!") + } else { + throw new IOException(s"${path.toUri.toString} already exists, " + + s"make sure to use a unique checkpoint path for each app.") + } + messageStoreOutputStream = fs.append(path) + } + + messageStoreOutputStream.writeBytes(s"${message.getId}${SEP}") + messageStoreOutputStream.writeBytes(s"${topic}${SEP}") + val timestamp = Calendar.getInstance().getTimeInMillis().toString + messageStoreOutputStream.writeBytes(s"${timestamp}${SEP}") + messageStoreOutputStream.write(message.getPayload()) + messageStoreOutputStream.writeBytes("\n") + addReceivedMessageInfo(1, messageSize) + consecutiveFailNum = 0 + logInfo(s"Message arrived, topic: $topic, message payload $message, " + + s"messageId: ${message.getId}, message size: ${messageSize}") + } + } + + override def deliveryComplete(token: IMqttDeliveryToken): Unit = { + // callback for publishers; not needed here. + } + + override def connectionLost(cause: Throwable): Unit = { + // auto reconnection is enabled, so just add a log here. + withLock[Unit] { + consecutiveFailNum += 1 + logWarning(s"Connection to mqtt server lost, " + + s"consecutive fail number $consecutiveFailNum", cause) + } + } + + override def connectComplete(reconnect: Boolean, serverURI: String): Unit = { + logInfo(s"Connect complete $serverURI. Is it a reconnect?: $reconnect") + } + } + client.setCallback(callback) + client.connect(mqttConnectOptions) + client.subscribe(topic, qos) + } + + /** Stop this source and free any resources it has allocated. */ + override def stop(): Unit = { + logInfo("Stop mqtt source.") + client.disconnect() + client.close() + withLock[Unit] { + if (messageStoreOutputStream != null) { + messageStoreOutputStream.hflush() + messageStoreOutputStream.hsync() + messageStoreOutputStream.close() + messageStoreOutputStream = null + } + fs.close() + } + } + + /** Returns the maximum available offset for this source. */ + override def getOffset: Option[Offset] = { + withLock[Option[Offset]] { + assert(consecutiveFailNum < maxRetryNumber, + s"Writing message data failed ${maxRetryNumber} times in a row.") + val result = if (!hasNewMessageForCurrentBatch()) { + if (lastOffset == -1) { + // first submission and no message has arrived yet. + None + } else { + // no message has arrived for this batch. + Some(LongOffset(lastOffset)) + } + } else { + // check whether the batch to be executed is still being written. + if (currentMessageDataFileOffset == lastOffset + 1) { + startWriteNewDataFile() + } + lastOffset += 1 + Some(LongOffset(lastOffset)) + } + logInfo(s"getOffset result $result") + result + } + } + + /** + * Returns the data that is between the offsets (`start`, `end`]. + * The batch returns the data in the file ${checkpointPath}/receivedMessages/${end}. + * `start` and `end` satisfy the relationship `end` = `start` + 1 + * if `start` is not None.
+ */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + withLock[Unit]{ + assert(consecutiveFailNum < maxRetryNumber, + s"Writing message data failed ${maxRetryNumber} times in a row.") + } + logInfo(s"getBatch with start = $start, end = $end") + val endIndex = getOffsetValue(end) + if (start.nonEmpty) { + val startIndex = getOffsetValue(start.get) + assert(startIndex + 1 == endIndex, + s"start offset: ${startIndex} and end offset: ${endIndex} do not match") + } + logTrace(s"Create a data frame using hdfs file $receivedDataPath/$endIndex") + val rdd = sqlContext.sparkContext.textFile(s"$receivedDataPath/$endIndex") + .map{case str => + // get the message id + val idIndex = str.indexOf(SEP) + val messageId = str.substring(0, idIndex).toInt + // get topic + var subStr = str.substring(idIndex + SEP.length) + val topicIndex = subStr.indexOf(SEP) + val topic = UTF8String.fromString(subStr.substring(0, topicIndex)) + // get timestamp + subStr = subStr.substring(topicIndex + SEP.length) + val timestampIndex = subStr.indexOf(SEP) + /* + val timestamp = Timestamp.valueOf( + MQTTStreamConstants.DATE_FORMAT.format(subStr.substring(0, timestampIndex).toLong)) + */ + val timestamp = subStr.substring(0, timestampIndex).toLong + // get payload + subStr = subStr.substring(timestampIndex + SEP.length) + val payload = UTF8String.fromString(subStr).getBytes + InternalRow(messageId, topic, payload, timestamp) + } + sqlContext.internalCreateDataFrame(rdd, MQTTStreamConstants.SCHEMA_DEFAULT, true) + } + + /** + * Remove the data file for the offset. + * + * @param end the end offset up to which all data has been committed. + */ + override def commit(end: Offset): Unit = { + val offsetValue = getOffsetValue(end) + if (offsetValue >= minBatchesToRetain) { + val deleteDataFileOffset = offsetValue - minBatchesToRetain + try { + fs.delete(new Path(s"$receivedDataPath/$deleteDataFileOffset"), false) + logInfo(s"Deleted committed offset data file $deleteDataFileOffset successfully!") + } catch { + case e: Exception => + logWarning(s"Deleting committed offset data file $deleteDataFileOffset failed. ", e) + } + } + } + + private def getOffsetValue(offset: Offset): Long = { + val offsetValue = offset match { + case o: LongOffset => o.offset + case so: SerializedOffset => + so.json.toLong + } + offsetValue + } +} +object HdfsBasedMQTTStreamSource { + + var hadoopConfig: Configuration = _ +} http://git-wip-us.apache.org/repos/asf/bahir/blob/172d7096/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala new file mode 100644 index 0000000..777db16 --- /dev/null +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.bahir.sql.streaming.mqtt + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.MiniDFSCluster +import org.apache.hadoop.security.Groups +import org.eclipse.paho.client.mqttv3.MqttException +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.sql._ +import org.apache.spark.sql.mqtt.{HdfsBasedMQTTStreamSource, HDFSMQTTSourceProvider} +import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery} + +import org.apache.bahir.utils.FileHelper + +class HDFSBasedMQTTStreamSourceSuite + extends SparkFunSuite + with SharedSparkContext + with BeforeAndAfter { + + protected var mqttTestUtils: MQTTTestUtils = _ + protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/") + protected var hadoop: MiniDFSCluster = _ + + before { + tempDir.mkdirs() + if (!tempDir.exists()) { + throw new IllegalStateException("Unable to create temp directories.") + } + tempDir.deleteOnExit() + mqttTestUtils = new MQTTTestUtils(tempDir) + mqttTestUtils.setup() + hadoop = HDFSTestUtils.prepareHadoop() + } + + after { + mqttTestUtils.teardown() + HDFSTestUtils.shutdownHadoop() + FileHelper.deleteFileQuietly(tempDir) + } + + protected val tmpDir: String = tempDir.getAbsolutePath + + protected def writeStreamResults(sqlContext: SQLContext, dataFrame: DataFrame): StreamingQuery = { + import sqlContext.implicits._ + val query: StreamingQuery = dataFrame.selectExpr("CAST(payload AS STRING)").as[String] + .writeStream.format("csv").start(s"$tempDir/t.csv") + while (!query.status.isTriggerActive) { + Thread.sleep(20) + } + query + } + + protected def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = { + import sqlContext.implicits._ + val asList = + sqlContext.read + .csv(s"$tmpDir/t.csv").as[String] + .collectAsList().asScala + asList + } + + protected def createStreamingDataFrame(dir: String = tmpDir): (SQLContext, DataFrame) = { + + val sqlContext: SQLContext = SparkSession.builder() + .getOrCreate().sqlContext + + sqlContext.setConf("spark.sql.streaming.checkpointLocation", + s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint") + + val ds: DataStreamReader = + sqlContext.readStream.format("org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider") + .option("topic", "test").option("clientId", "clientId").option("connectionTimeout", "120") + .option("keepAlive", "1200").option("autoReconnect", "false") + .option("cleanSession", "true").option("QoS", "2") + val dataFrame = ds.load("tcp://" + mqttTestUtils.brokerUri) + (sqlContext, dataFrame) + } +} + +object HDFSTestUtils { + + private var hadoop: MiniDFSCluster = _ + + def prepareHadoop(): MiniDFSCluster = { + if (hadoop != null) { + hadoop + } else { + val baseDir = new File(System.getProperty("java.io.tmpdir") + "/hadoop").getAbsoluteFile + System.setProperty("HADOOP_USER_NAME", "test") + val conf = new Configuration + conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, 
baseDir.getAbsolutePath) + conf.setBoolean("dfs.namenode.acls.enabled", true) + conf.setBoolean("dfs.permissions", true) + Groups.getUserToGroupsMappingService(conf) + val builder = new MiniDFSCluster.Builder(conf) + hadoop = builder.build + conf.set("fs.defaultFS", "hdfs://localhost:" + hadoop.getNameNodePort + "/") + HdfsBasedMQTTStreamSource.hadoopConfig = conf + hadoop + } + } + + def shutdownHadoop(): Unit = { + if (null != hadoop) { + hadoop.shutdown(true) + } + hadoop = null + } +} + +class BasicHDFSBasedMQTTSourceSuite extends HDFSBasedMQTTStreamSourceSuite { + + test("basic usage") { + + val sendMessage = "MQTT is a message queue." + + val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame() + + val query = writeStreamResults(sqlContext, dataFrame) + mqttTestUtils.publishData("test", sendMessage) + query.processAllAvailable() + query.awaitTermination(10000) + + val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext) + + assert(resultBuffer.size == 1) + assert(resultBuffer.head == sendMessage) + } + + test("Send and receive 50 messages.") { + + val sendMessage = "MQTT is a message queue." + + val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame() + + val q = writeStreamResults(sqlContext, dataFrame) + + mqttTestUtils.publishData("test", sendMessage, 50) + q.processAllAvailable() + q.awaitTermination(10000) + + val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext) + + assert(resultBuffer.size == 50) + assert(resultBuffer.head == sendMessage) + } + + test("no server up") { + val provider = new HDFSMQTTSourceProvider + val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext + intercept[MqttException] { + provider.createSource( + sqlContext, + s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0", + Some(MQTTStreamConstants.SCHEMA_DEFAULT), + "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider", + Map("brokerUrl" -> "tcp://localhost:1881", "topic" -> "test") + ) + } + } + + test("params not provided.") { + val provider = new HDFSMQTTSourceProvider + val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext + intercept[IllegalArgumentException] { + provider.createSource( + sqlContext, + s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0", + Some(MQTTStreamConstants.SCHEMA_DEFAULT), + "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider", + Map() + ) + } + } +}
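A note on Hadoop configuration: as the MiniDFSCluster-based test utilities above illustrate, a preconfigured org.apache.hadoop.conf.Configuration can be handed to the source through the HdfsBasedMQTTStreamSource companion object before the query starts; when it is left unset, the source falls back to a plain new Configuration(). A minimal sketch, with a placeholder namenode URI:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.mqtt.HdfsBasedMQTTStreamSource

    val conf = new Configuration()
    conf.set("fs.defaultFS", "hdfs://namenode:8020/") // placeholder namenode address
    // Picked up when the source is constructed, instead of a fresh Configuration().
    HdfsBasedMQTTStreamSource.hadoopConfig = conf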
