Repository: bahir Updated Branches: refs/heads/master e79a960fa -> aecd5fd9f
[BAHIR-49] Sink for SQL Streaming MQTT module Closes #68 Project: http://git-wip-us.apache.org/repos/asf/bahir/repo Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/aecd5fd9 Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/aecd5fd9 Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/aecd5fd9 Branch: refs/heads/master Commit: aecd5fd9f00e40b64ebe81269396bfdc42f8ed00 Parents: e79a960 Author: Lukasz Antoniak <[email protected]> Authored: Mon Jul 9 07:42:09 2018 +0200 Committer: Luciano Resende <[email protected]> Committed: Wed Nov 28 11:12:12 2018 +0100 ---------------------------------------------------------------------- sql-streaming-mqtt/README.md | 41 +++-- .../streaming/mqtt/JavaMQTTSinkWordCount.java | 91 ++++++++++ .../sql/streaming/mqtt/MQTTSinkWordCount.scala | 85 +++++++++ ....apache.spark.sql.sources.DataSourceRegister | 1 + .../sql/streaming/mqtt/CachedMQTTClient.scala | 125 ++++++++++++++ .../sql/streaming/mqtt/MQTTStreamSink.scala | 122 +++++++++++++ .../sql/streaming/mqtt/MQTTStreamSource.scala | 60 +------ .../bahir/sql/streaming/mqtt/MQTTUtils.scala | 117 +++++++++++++ .../streaming/mqtt/LocalMessageStoreSuite.scala | 2 +- .../streaming/mqtt/MQTTStreamSinkSuite.scala | 172 +++++++++++++++++++ .../streaming/mqtt/MQTTStreamSourceSuite.scala | 13 +- .../sql/streaming/mqtt/MQTTTestUtils.scala | 33 ++++ 12 files changed, 789 insertions(+), 73 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/README.md ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/README.md b/sql-streaming-mqtt/README.md index b7f0602..a426a11 100644 --- a/sql-streaming-mqtt/README.md +++ b/sql-streaming-mqtt/README.md @@ -1,4 +1,4 @@ -A library for reading data from MQTT Servers using Spark SQL Streaming ( or Structured streaming.). 
+A library for writing and reading data from MQTT Servers using Spark SQL Streaming (or Structured streaming). ## Linking @@ -26,16 +26,25 @@ This library is compiled for Scala 2.11 only, and intends to support Spark 2.0 o ## Examples -A SQL Stream can be created with data streams received through MQTT Server using, +SQL Stream can be created with data streams received through MQTT Server using: sqlContext.readStream .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") .option("topic", "mytopic") .load("tcp://localhost:1883") -## Enable recovering from failures. +SQL Stream may be also transferred into MQTT messages using: -Setting values for option `localStorage` and `clientId` helps in recovering in case of a restart, by restoring the state where it left off before the shutdown. + sqlContext.writeStream + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider") + .option("checkpointLocation", "/path/to/localdir") + .outputMode("complete") + .option("topic", "mytopic") + .load("tcp://localhost:1883") + +## Source recovering from failures + +Setting values for option `localStorage` and `clientId` helps in recovering in case of source restart, by restoring the state where it left off before the shutdown. sqlContext.readStream .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") @@ -44,14 +53,14 @@ Setting values for option `localStorage` and `clientId` helps in recovering in c .option("clientId", "some-client-id") .load("tcp://localhost:1883") -## Configuration options. +## Configuration options -This source uses [Eclipse Paho Java Client](https://eclipse.org/paho/clients/java/). Client API documentation is located [here](http://www.eclipse.org/paho/files/javadoc/index.html). +This connector uses [Eclipse Paho Java Client](https://eclipse.org/paho/clients/java/). Client API documentation is located [here](http://www.eclipse.org/paho/files/javadoc/index.html). - * `brokerUrl` A url MqttClient connects to. 
Set this or `path` as the url of the Mqtt Server. e.g. tcp://localhost:1883. + * `brokerUrl` An URL MqttClient connects to. Set this or `path` as the URL of the Mqtt Server. e.g. tcp://localhost:1883. * `persistence` By default it is used for storing incoming messages on disk. If `memory` is provided as value for this option, then recovery on restart is not supported. * `topic` Topic MqttClient subscribes to. - * `clientId` clientId, this client is assoicated with. Provide the same value to recover a stopped client. + * `clientId` clientId, this client is associated with. Provide the same value to recover a stopped source client. MQTT sink ignores client identifier, because Spark batch can be distributed across multiple workers whereas MQTT broker does not allow simultanous connections with same ID from multiple hosts. * `QoS` The maximum quality of service to subscribe each topic at. Messages published at a lower quality of service will be received at the published QoS. Messages published at a higher quality of service will be received using the QoS specified on the subscribe. * `username` Sets the user name to use for the connection to Mqtt Server. Do not set it, if server does not need this. Setting it empty will lead to errors. * `password` Sets the password to use for the connection. @@ -61,7 +70,17 @@ This source uses [Eclipse Paho Java Client](https://eclipse.org/paho/clients/jav * `mqttVersion` Same as `MqttConnectOptions.setMqttVersion`. * `maxInflight` Same as `MqttConnectOptions.setMaxInflight` * `autoReconnect` Same as `MqttConnectOptions.setAutomaticReconnect` - + +## Environment variables + +Custom environment variables allowing to manage MQTT connectivity performed by sink connector: + + * `spark.mqtt.client.connect.attempts` Number of attempts sink will try to connect to MQTT broker before failing. + * `spark.mqtt.client.connect.backoff` Delay in milliseconds to wait before retrying connection to the server. 
+ * `spark.mqtt.connection.cache.timeout` Sink connector caches MQTT connections. Idle connections will be closed after timeout milliseconds. + * `spark.mqtt.client.publish.attempts` Number of attempts to publish the message before failing the task. + * `spark.mqtt.client.publish.backoff` Delay in milliseconds to wait before retrying send operation. + ### Scala API An example, for scala API to count words from incoming message stream. @@ -86,7 +105,7 @@ An example, for scala API to count words from incoming message stream. query.awaitTermination() -Please see `MQTTStreamWordCount.scala` for full example. +Please see `MQTTStreamWordCount.scala` for full example. Review `MQTTSinkWordCount.scala`, if interested in publishing data to MQTT broker. ### Java API @@ -119,7 +138,7 @@ An example, for Java API to count words from incoming message stream. query.awaitTermination(); -Please see `JavaMQTTStreamWordCount.java` for full example. +Please see `JavaMQTTStreamWordCount.java` for full example. Review `JavaMQTTSinkWordCount.java`, if interested in publishing data to MQTT broker. ## Best Practices. http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTSinkWordCount.java ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTSinkWordCount.java b/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTSinkWordCount.java new file mode 100644 index 0000000..8e5006a --- /dev/null +++ b/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTSinkWordCount.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.bahir.examples.sql.streaming.mqtt; + +import java.io.File; +import java.util.Arrays; +import java.util.Iterator; + +import org.apache.commons.io.FileUtils; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.streaming.StreamingQuery; + +/** + * Counts words in UTF-8 encoded, '\n' delimited text received from local socket + * and publishes results on MQTT topic. + * + * Usage: JavaMQTTSinkWordCount <port> <brokerUrl> <topic> + * <port> represents local network port on which program is listening for input. + * <brokerUrl> and <topic> describe the MQTT server that structured streaming + * would connect and send data. + * + * To run example on your local machine, a MQTT Server should be up and running. + * Linux users may leverage 'nc -lk <port>' to listen on local port and wait + * for Spark socket connection. 
+ */ +public class JavaMQTTSinkWordCount { + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaMQTTSinkWordCount <port> <brokerUrl> <topic>"); + System.exit(1); + } + + String checkpointDir = System.getProperty("java.io.tmpdir") + "/mqtt-example/"; + // Remove checkpoint directory. + FileUtils.deleteDirectory(new File(checkpointDir)); + + Integer port = Integer.valueOf(args[0]); + String brokerUrl = args[1]; + String topic = args[2]; + + SparkSession spark = SparkSession.builder() + .appName("JavaMQTTSinkWordCount").master("local[4]") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from local network socket. + Dataset<String> lines = spark.readStream() + .format("socket") + .option("host", "localhost").option("port", port) + .load().select("value").as(Encoders.STRING()); + + // Split the lines into words. + Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() { + @Override + public Iterator<String> call(String x) { + return Arrays.asList(x.split(" ")).iterator(); + } + }, Encoders.STRING()); + + // Generate running word count. + Dataset<Row> wordCounts = words.groupBy("value").count(); + + // Start publishing the counts to MQTT server. 
+ StreamingQuery query = wordCounts.writeStream() + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider") + .option("checkpointLocation", checkpointDir) + .outputMode("complete") + .option("topic", topic) + .option("localStorage", checkpointDir) + .start(brokerUrl); + + query.awaitTermination(); + } +} http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTSinkWordCount.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTSinkWordCount.scala b/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTSinkWordCount.scala new file mode 100644 index 0000000..c869adc --- /dev/null +++ b/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTSinkWordCount.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.bahir.examples.sql.streaming.mqtt + +import java.io.File + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql.SparkSession + +/** + * Counts words in UTF-8 encoded, '\n' delimited text received from local socket + * and publishes results on MQTT topic. + * + * Usage: MQTTSinkWordCount <port> <brokerUrl> <topic> + * <port> represents local network port on which program is listening for input. + * <brokerUrl> and <topic> describe the MQTT server that structured streaming + * would connect and send data. + * + * To run example on your local machine, a MQTT Server should be up and running. + * Linux users may leverage 'nc -lk <port>' to listen on local port and wait + * for Spark socket connection. + */ +object MQTTSinkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + // scalastyle:off + System.err.println("Usage: MQTTSinkWordCount <port> <brokerUrl> <topic>") + // scalastyle:on + System.exit(1) + } + + val checkpointDir = System.getProperty("java.io.tmpdir") + "/mqtt-example/" + // Remove checkpoint directory. + FileUtils.deleteDirectory(new File(checkpointDir)) + + val port = args(0) + val brokerUrl = args(1) + val topic = args(2) + + val spark = SparkSession.builder + .appName("MQTTSinkWordCount").master("local[4]") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from local network socket. + val lines = spark.readStream + .format("socket") + .option("host", "localhost").option("port", port) + .load().select("value").as[String] + + // Split the lines into words. + val words = lines.flatMap(_.split(" ")) + + // Generate running word count. + val wordCounts = words.groupBy("value").count() + + // Start publishing the counts to MQTT server. 
+ val query = wordCounts.writeStream + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider") + .option("checkpointLocation", checkpointDir) + .outputMode("complete") + .option("topic", topic) + .option("localStorage", checkpointDir) + .start(brokerUrl) + + query.awaitTermination() + } +} http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 634bc2f..d3899e6 100644 --- a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -15,4 +15,5 @@ # limitations under the License. # +org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider \ No newline at end of file http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala new file mode 100644 index 0000000..f825eea --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.bahir.sql.streaming.mqtt + +import java.util.concurrent.{ExecutionException, TimeUnit} + +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalListener, RemovalNotification} +import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} +import org.eclipse.paho.client.mqttv3.{IMqttDeliveryToken, MqttCallbackExtended, MqttClient, MqttClientPersistence, MqttException, MqttMessage} +import scala.util.control.NonFatal + +import org.apache.spark.SparkEnv + +import org.apache.bahir.utils.Logging + + +private[mqtt] object CachedMQTTClient extends Logging { + private lazy val cacheExpireTimeout: Long = + SparkEnv.get.conf.getTimeAsMs("spark.mqtt.connection.cache.timeout", "10m") + private lazy val connectAttempts: Int = + SparkEnv.get.conf.getInt("spark.mqtt.client.connect.attempts", -1) + private lazy val connectBackoff: Long = + SparkEnv.get.conf.getTimeAsMs("spark.mqtt.client.connect.backoff", "5s") + + private val cacheLoader = new CacheLoader[Seq[(String, String)], + (MqttClient, MqttClientPersistence)] { + override def load(config: Seq[(String, String)]): (MqttClient, MqttClientPersistence) = { + log.debug(s"Creating new MQTT client with params: $config") + createMqttClient(Map(config.map(s => s._1 -> 
s._2): _*)) + } + } + + private val removalListener = new RemovalListener[Seq[(String, String)], + (MqttClient, MqttClientPersistence)]() { + override def onRemoval(notification: RemovalNotification[Seq[(String, String)], + (MqttClient, MqttClientPersistence)]): Unit = { + val params: Seq[(String, String)] = notification.getKey + val client: MqttClient = notification.getValue._1 + val persistence: MqttClientPersistence = notification.getValue._2 + log.debug(s"Evicting MQTT client $client params: $params, due to ${notification.getCause}") + closeMqttClient(params, client, persistence) + } + } + + private lazy val cache: LoadingCache[Seq[(String, String)], + (MqttClient, MqttClientPersistence)] = + CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS) + .removalListener(removalListener) + .build[Seq[(String, String)], (MqttClient, MqttClientPersistence)](cacheLoader) + + private def createMqttClient(config: Map[String, String]): + (MqttClient, MqttClientPersistence) = { + val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _) = + MQTTUtils.parseConfigParams(config) + val client = new MqttClient(brokerUrl, clientId, persistence) + val callback = new MqttCallbackExtended() { + override def messageArrived(topic : String, message: MqttMessage): Unit = synchronized { + } + + override def deliveryComplete(token: IMqttDeliveryToken): Unit = { + } + + override def connectionLost(cause: Throwable): Unit = { + log.warn("Connection to mqtt server lost.", cause) + } + + override def connectComplete(reconnect: Boolean, serverURI: String): Unit = { + log.info(s"Connect complete $serverURI. 
Is it a reconnect?: $reconnect") + } + } + client.setCallback(callback) + Retry(connectAttempts, connectBackoff, classOf[MqttException]) { + client.connect(mqttConnectOptions) + } + (client, persistence) + } + + private def closeMqttClient(params: Seq[(String, String)], + client: MqttClient, persistence: MqttClientPersistence): Unit = { + try { + client.disconnect() + persistence.close() + client.close() + } catch { + case NonFatal(e) => log.warn(s"Error while closing MQTT client ${e.getMessage}", e) + } + } + + private[mqtt] def getOrCreate(parameters: Map[String, String]): MqttClient = { + try { + cache.get(mapToSeq(parameters))._1 + } catch { + case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError) + if e.getCause != null => throw e.getCause + } + } + + private[mqtt] def close(parameters: Map[String, String]): Unit = { + cache.invalidate(mapToSeq(parameters)) + } + + private[mqtt] def clear(): Unit = { + log.debug("Cleaning MQTT client cache") + cache.invalidateAll() + } + + private def mapToSeq(parameters: Map[String, String]): Seq[(String, String)] = { + parameters.toSeq.sortBy(x => x._1) + } +} http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala new file mode 100644 index 0000000..8654b88 --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.bahir.sql.streaming.mqtt + +import java.nio.charset.Charset + +import scala.collection.JavaConverters._ + +import org.eclipse.paho.client.mqttv3.MqttException + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.sources.{PackedRowCommitMessage, PackedRowWriterFactory} +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +import org.apache.bahir.utils.Logging + + +class MQTTStreamWriter (schema: StructType, parameters: DataSourceOptions) + extends StreamWriter with Logging { + private lazy val publishAttempts: Int = + SparkEnv.get.conf.getInt("spark.mqtt.client.publish.attempts", -1) + private lazy val publishBackoff: Long = + SparkEnv.get.conf.getTimeAsMs("spark.mqtt.client.publish.backoff", "5s") + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + + private var topic: String = _ + private var qos: Int = 
-1 + + initialize() + private def initialize(): Unit = { + val (_, _, topic_, _, _, qos_) = MQTTUtils.parseConfigParams( + collection.immutable.HashMap() ++ parameters.asMap().asScala + ) + topic = topic_ + qos = qos_ + } + + override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + commit(messages) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + val rows = messages.collect { + case PackedRowCommitMessage(rs) => rs + }.flatten + + // Skipping client identifier as single batch can be distributed to multiple + // Spark worker process. MQTT server does not support two connections + // declaring same client ID at given point in time. + val params_ = Seq() ++ parameters.asMap().asScala.toSeq.filterNot( + _._1.equalsIgnoreCase("clientId") + ) + // IMPL Note: Had to declare new value reference due to serialization requirements. + val topic_ = topic + val qos_ = qos + val publishAttempts_ = publishAttempts + val publishBackoff_ = publishBackoff + + val data = spark.createDataFrame(rows.toList.asJava, schema) + data.foreachPartition ( + iterator => iterator.foreach( + row => { + val client = CachedMQTTClient.getOrCreate(params_.toMap) + val message = row.mkString.getBytes(Charset.defaultCharset()) + Retry(publishAttempts_, publishBackoff_, classOf[MqttException]) { + // In case of errors, retry sending the message. 
+ client.publish(topic_, message, qos_, false) + } + } + ) + ) + } + + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +case class MQTTRelation(override val sqlContext: SQLContext, data: DataFrame) + extends BaseRelation { + override def schema: StructType = data.schema +} + +class MQTTStreamSinkProvider extends DataSourceV2 with StreamWriteSupport + with DataSourceRegister with CreatableRelationProvider { + override def createStreamWriter(queryId: String, schema: StructType, + mode: OutputMode, options: DataSourceOptions): StreamWriter = { + new MQTTStreamWriter(schema, options) + } + + override def createRelation(sqlContext: SQLContext, mode: SaveMode, + parameters: Map[String, String], data: DataFrame): BaseRelation = { + MQTTRelation(sqlContext, data) + } + + override def shortName(): String = "mqtt" +} http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala index 2f75ee2..98bc60e 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala @@ -29,7 +29,6 @@ import scala.collection.immutable.IndexedSeq import scala.collection.mutable.ListBuffer import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence} import org.apache.spark.sql._ import org.apache.spark.sql.sources.DataSourceRegister @@ -83,7 +82,7 @@ class MQTTMessage(m: MqttMessage, val topic: String) extends 
Serializable { * incoming messages on disk. If memory is provided as option, then recovery on * restart is not supported. * @param topic topic MqttClient subscribes to. - * @param clientId clientId, this client is assoicated with. Provide the same value to recover + * @param clientId clientId, this client is associated with. Provide the same value to recover * a stopped client. * @param mqttConnectOptions an instance of MqttConnectOptions for this Source. * @param qos the maximum quality of service to subscribe each topic at.Messages published at @@ -244,60 +243,11 @@ class MQTTStreamSourceProvider extends DataSourceV2 throw e("The mqtt source does not support a user-specified schema.") } - val brokerUrl = parameters.get("brokerUrl").orElse(parameters.get("path").orElse(null)) + import scala.collection.JavaConverters._ + val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) = + MQTTUtils.parseConfigParams(collection.immutable.HashMap() ++ parameters.asMap().asScala) - if (brokerUrl == null) { - throw e("Please provide a broker url, with option(\"brokerUrl\", ...).") - } - - val persistence: MqttClientPersistence = parameters.get("persistence").orElse("") match { - case "memory" => new MemoryPersistence() - case _ => val localStorage: String = parameters.get("localStorage").orElse("") - localStorage match { - case "" => new MqttDefaultFilePersistence() - case x => new MqttDefaultFilePersistence(x) - } - } - - // if default is subscribe everything, it leads to getting lot unwanted system messages. - val topic: String = parameters.get("topic").orElse(null) - if (topic == null) { - throw e("Please specify a topic, by .options(\"topic\",...)") - } - - val clientId: String = parameters.get("clientId").orElse { - log.warn("If `clientId` is not set, a random value is picked up." 
+ - " Recovering from failure is not supported in such a case.") - MqttClient.generateClientId()} - - val username: String = parameters.get("username").orElse(null) - val password: String = parameters.get("password").orElse(null) - - val connectionTimeout: Int = parameters.get("connectionTimeout").orElse( - MqttConnectOptions.CONNECTION_TIMEOUT_DEFAULT.toString).toInt - val keepAlive: Int = parameters.get("keepAlive").orElse(MqttConnectOptions - .KEEP_ALIVE_INTERVAL_DEFAULT.toString).toInt - val mqttVersion: Int = parameters.get("mqttVersion").orElse(MqttConnectOptions - .MQTT_VERSION_DEFAULT.toString).toInt - val cleanSession: Boolean = parameters.get("cleanSession").orElse("true").toBoolean - val qos: Int = parameters.get("QoS").orElse("1").toInt - val autoReconnect: Boolean = parameters.get("autoReconnect").orElse("false").toBoolean - val maxInflight: Int = parameters.get("maxInflight").orElse("60").toInt - val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions() - mqttConnectOptions.setAutomaticReconnect(autoReconnect) - mqttConnectOptions.setCleanSession(cleanSession) - mqttConnectOptions.setConnectionTimeout(connectionTimeout) - mqttConnectOptions.setKeepAliveInterval(keepAlive) - mqttConnectOptions.setMqttVersion(mqttVersion) - mqttConnectOptions.setMaxInflight(maxInflight) - (username, password) match { - case (u: String, p: String) if u != null && p != null => - mqttConnectOptions.setUserName(u) - mqttConnectOptions.setPassword(p.toCharArray) - case _ => - } - - new MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId, + new MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId, mqttConnectOptions, qos) } override def shortName(): String = "mqtt" http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala ---------------------------------------------------------------------- diff --git 
a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala new file mode 100644 index 0000000..79fe7a2 --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.bahir.sql.streaming.mqtt + +import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttConnectOptions} +import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence} + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +import org.apache.bahir.utils.Logging + + +private[mqtt] object MQTTUtils extends Logging { + private[mqtt] def parseConfigParams(config: Map[String, String]): + (String, String, String, MqttClientPersistence, MqttConnectOptions, Int) = { + def e(s: String) = new IllegalArgumentException(s) + val parameters = CaseInsensitiveMap(config) + + val brokerUrl: String = parameters.getOrElse("brokerUrl", parameters.getOrElse("path", + throw e("Please provide a `brokerUrl` by specifying path or .options(\"brokerUrl\",...)"))) + + val persistence: MqttClientPersistence = parameters.get("persistence") match { + case Some("memory") => new MemoryPersistence() + case _ => val localStorage: Option[String] = parameters.get("localStorage") + localStorage match { + case Some(x) => new MqttDefaultFilePersistence(x) + case None => new MqttDefaultFilePersistence() + } + } + + // if default is subscribe everything, it leads to getting a lot of unwanted system messages. + val topic: String = parameters.getOrElse("topic", + throw e("Please specify a topic, by .options(\"topic\",...)")) + + val clientId: String = parameters.getOrElse("clientId", { + log.warn("If `clientId` is not set, a random value is picked up."
+ + "\nRecovering from failure is not supported in such a case.") + MqttClient.generateClientId()}) + + val username: Option[String] = parameters.get("username") + val password: Option[String] = parameters.get("password") + val connectionTimeout: Int = parameters.getOrElse("connectionTimeout", + MqttConnectOptions.CONNECTION_TIMEOUT_DEFAULT.toString).toInt + val keepAlive: Int = parameters.getOrElse("keepAlive", MqttConnectOptions + .KEEP_ALIVE_INTERVAL_DEFAULT.toString).toInt + val mqttVersion: Int = parameters.getOrElse("mqttVersion", MqttConnectOptions + .MQTT_VERSION_DEFAULT.toString).toInt + val cleanSession: Boolean = parameters.getOrElse("cleanSession", "false").toBoolean + val qos: Int = parameters.getOrElse("QoS", "1").toInt + val autoReconnect: Boolean = parameters.getOrElse("autoReconnect", "false").toBoolean + val maxInflight: Int = parameters.getOrElse("maxInflight", "60").toInt + + val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions() + mqttConnectOptions.setAutomaticReconnect(autoReconnect) + mqttConnectOptions.setCleanSession(cleanSession) + mqttConnectOptions.setConnectionTimeout(connectionTimeout) + mqttConnectOptions.setKeepAliveInterval(keepAlive) + mqttConnectOptions.setMqttVersion(mqttVersion) + mqttConnectOptions.setMaxInflight(maxInflight) + (username, password) match { + case (Some(u: String), Some(p: String)) => + mqttConnectOptions.setUserName(u) + mqttConnectOptions.setPassword(p.toCharArray) + case _ => + } + + (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) + } +} + +private[mqtt] object Retry { + /** + * Retry invocation of given code. + * @param attempts Number of attempts to try executing given code. -1 represents infinity. + * @param pauseMs Number of backoff milliseconds. + * @param retryExceptions Types of exceptions to retry. + * @param code Function to execute. + * @tparam A Type parameter. + * @return Returns result of function execution or exception in case of failure. 
+ */ + def apply[A](attempts: Int, pauseMs: Long, retryExceptions: Class[_]*)(code: => A): A = { + var result: Option[A] = None + var success = false + var remaining = attempts + while ( ! success ) { + try { + remaining -= 1 + result = Some( code ) + success = true + } + catch { + case e: Exception => + if (retryExceptions.contains(e.getClass) && (attempts == -1 || remaining > 0)) { + Thread.sleep(pauseMs) + } else { + throw e + } + } + } + result.get + } +} http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala index 0a2a079..d1bbe18 100644 --- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala @@ -58,7 +58,7 @@ class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter { assert(testData === deserialized) } - test("Store and retreive") { + test("Store and retrieve") { store.store(1, testData) val result: Seq[Int] = store.retrieve(1) assert(testData === result) http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSinkSuite.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSinkSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSinkSuite.scala new file mode 100644 index 0000000..14ea962 --- /dev/null +++ 
b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSinkSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.bahir.sql.streaming.mqtt + +import java.io.File +import java.net.ConnectException +import java.util + +import org.eclipse.paho.client.mqttv3.MqttClient +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.Future + +import org.apache.spark.{SharedSparkContext, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.sources.PackedRowCommitMessage +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +import org.apache.bahir.utils.BahirUtils + + +class MQTTStreamSinkSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { + protected var 
mqttTestUtils: MQTTTestUtils = _ + protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/") + protected val messages = new mutable.HashMap[Int, String] + protected var testClient: MqttClient = _ + + before { + mqttTestUtils = new MQTTTestUtils(tempDir) + mqttTestUtils.setup() + tempDir.mkdirs() + messages.clear() + testClient = mqttTestUtils.subscribeData("test", messages) + } + + after { + testClient.disconnect() + testClient.close() + mqttTestUtils.teardown() + BahirUtils.recursiveDeleteDir(tempDir) + } + + protected def createContextAndDF(messages: String*): (SQLContext, DataFrame) = { + val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext + sqlContext.setConf("spark.sql.streaming.checkpointLocation", tempDir.getAbsolutePath) + import sqlContext.sparkSession.implicits._ + val stream = new MemoryStream[String](1, sqlContext) + stream.addData(messages.toSeq) + (sqlContext, stream.toDF()) + } + + protected def sendToMQTT(dataFrame: DataFrame): StreamingQuery = { + dataFrame.writeStream + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider") + .option("topic", "test").option("localStorage", tempDir.getAbsolutePath) + .option("clientId", "clientId").option("QoS", "2") + .start("tcp://" + mqttTestUtils.brokerUri) + } +} + +class BasicMQTTSinkSuite extends MQTTStreamSinkSuite { + test("broker down") { + SparkEnv.get.conf.set("spark.mqtt.client.connect.attempts", "1") + SparkSession.setActiveSession(SparkSession.builder().getOrCreate()) + val provider = new MQTTStreamSinkProvider + val parameters = Map( + "brokerUrl" -> "tcp://localhost:1883", + "topic" -> "test", + "localStorage" -> tempDir.getAbsoluteFile.toString + ) + val schema = StructType(StructField("value", StringType) :: Nil) + val messages : Array[Row] = Array(new GenericRowWithSchema(Array("value1"), schema)) + val thrown: Exception = intercept[SparkException] { + provider.createStreamWriter( + "query1", schema, 
OutputMode.Complete(), new DataSourceOptions(parameters.asJava) + ).commit(1, Array(PackedRowCommitMessage(messages))) + } + // SparkException -> MqttException -> ConnectException + assert(thrown.getCause.getCause.isInstanceOf[ConnectException]) + } + + test("basic usage") { + val msg1 = "Hello, World!" + val msg2 = "MQTT is a message queue." + val (_, dataFrame) = createContextAndDF(msg1, msg2) + + sendToMQTT(dataFrame).awaitTermination(3000) + + assert(Set(msg1, msg2).equals(messages.values.toSet)) + } + + test("send and receive 100 messages") { + val msg = List.tabulate(100)(n => "Hello, World!" + n) + val (_, dataFrame) = createContextAndDF(msg: _*) + + sendToMQTT(dataFrame).awaitTermination(3000) + + assert(Set(msg: _*).equals(messages.values.toSet)) + } + + test("missing configuration") { + val provider = new MQTTStreamSinkProvider + val parameters = Map( + "brokerUrl" -> "tcp://localhost:1883", + "localStorage" -> tempDir.getAbsoluteFile.toString + ) + intercept[IllegalArgumentException] { + provider.createStreamWriter( + "query1", null, OutputMode.Complete(), new DataSourceOptions(parameters.asJava) + ) + } + intercept[IllegalArgumentException] { + provider.createStreamWriter( + "query1", null, OutputMode.Complete(), + new DataSourceOptions(new util.HashMap[String, String]) + ) + } + } +} + +class StressTestMQTTSink extends MQTTStreamSinkSuite { + // run with -Xmx1024m + test("Send and receive messages of size 100MB.") { + val freeMemory: Long = Runtime.getRuntime.freeMemory() + log.info(s"Available memory before test run is ${freeMemory / (1024 * 1024)}MB.") + val noOfMsgs: Int = 200 + val noOfBatches: Int = 10 + + val messageBuilder = new StringBuilder() + for (i <- 0 until (500 * 1024)) yield messageBuilder.append(((i % 26) + 65).toChar) + val message = messageBuilder.toString() + val (_, dataFrame) = createContextAndDF( + // each message is ~500 KB + Array.fill(noOfMsgs / noOfBatches)(message): _* + ) + + import
scala.concurrent.ExecutionContext.Implicits.global + Future { + for (_ <- 0 until noOfBatches.toInt) { + sendToMQTT(dataFrame) + } + } + def waitForMessages(): Boolean = { + messages.size == noOfMsgs + } + + mqttTestUtils.sleepUntil(waitForMessages(), 60000) + + assert(messages.size == noOfMsgs) + assert(messageBuilder.toString().equals(messages.head._2)) + } +} http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala index 2ce72da..bb82715 100644 --- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala @@ -74,10 +74,11 @@ class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with B asList } - protected def createStreamingDataframe(dir: String = tmpDir, + protected def createStreamingDataFrame(dir: String = tmpDir, filePersistence: Boolean = false): (SQLContext, DataFrame) = { - val sqlContext: SQLContext = new SQLContext(sc) + val sqlContext: SQLContext = SparkSession.builder() + .getOrCreate().sqlContext sqlContext.setConf("spark.sql.streaming.checkpointLocation", tmpDir) @@ -104,7 +105,7 @@ class BasicMQTTSourceSuite extends MQTTStreamSourceSuite { val sendMessage = "MQTT is a message queue." 
- val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe() + val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame() val query = writeStreamResults(sqlContext, dataFrame) mqttTestUtils.publishData("test", sendMessage) @@ -121,7 +122,7 @@ class BasicMQTTSourceSuite extends MQTTStreamSourceSuite { val sendMessage = "MQTT is a message queue." - val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe() + val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame() val q = writeStreamResults(sqlContext, dataFrame) @@ -137,7 +138,7 @@ class BasicMQTTSourceSuite extends MQTTStreamSourceSuite { test("no server up") { val provider = new MQTTStreamSourceProvider - val sqlContext: SQLContext = new SQLContext(sc) + val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext val parameters = new DataSourceOptions(Map("brokerUrl" -> "tcp://localhost:1881", "topic" -> "test", "localStorage" -> tmpDir).asJava) intercept[MqttException] { @@ -174,7 +175,7 @@ class StressTestMQTTSource extends MQTTStreamSourceSuite { for (i <- 0 until (500 * 1024)) yield messageBuilder.append(((i % 26) + 65).toChar) val sendMessage = messageBuilder.toString() // each message is 50 KB - val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe() + val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame() val query = writeStreamResults(sqlContext, dataFrame) mqttTestUtils.publishData("test", sendMessage, noOfMsgs ) http://git-wip-us.apache.org/repos/asf/bahir/blob/aecd5fd9/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala index 817ec9a..893a145 100644 --- 
a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala @@ -19,6 +19,9 @@ package org.apache.bahir.sql.streaming.mqtt import java.io.File import java.net.{ServerSocket, URI} +import java.nio.charset.Charset + +import scala.collection.mutable import org.apache.activemq.broker.{BrokerService, TransportConnector} import org.eclipse.paho.client.mqttv3._ @@ -103,4 +106,34 @@ class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging { } } + def subscribeData(topic: String, messages: mutable.Map[Int, String]): MqttClient = { + val client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), null) + val callback = new MqttCallbackExtended() { + override def messageArrived(topic_ : String, message: MqttMessage): Unit = synchronized { + messages.put(message.getId, new String(message.getPayload, Charset.defaultCharset())) + } + + override def deliveryComplete(token: IMqttDeliveryToken): Unit = { + } + + override def connectionLost(cause: Throwable): Unit = { + } + + override def connectComplete(reconnect: Boolean, serverURI: String): Unit = { + } + } + client.setCallback(callback) + client.connect() + client.subscribe(topic) + client + } + + def sleepUntil(predicate: => Boolean, timeout: Long): Unit = { + val deadline = System.currentTimeMillis() + timeout + while (System.currentTimeMillis() < deadline) { + Thread.sleep(1000) + if (predicate) return + } + } + }
