Repository: bahir Updated Branches: refs/heads/master 29d8c7622 -> 70539a35d
[BAHIR-39] Add SQL Streaming MQTT support This provides support for using MQTT sources for the new Spark Structured Streaming. This uses MQTT client persistence layer to provide minimal fault tolerance. Closes #13 Project: http://git-wip-us.apache.org/repos/asf/bahir/repo Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/c98dd0fe Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/c98dd0fe Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/c98dd0fe Branch: refs/heads/master Commit: c98dd0feefa79c70be65de06a411a2f9c4fc42dc Parents: 29d8c76 Author: Prashant Sharma <[email protected]> Authored: Tue Jul 26 13:47:15 2016 +0530 Committer: Luciano Resende <[email protected]> Committed: Sat Aug 6 09:08:06 2016 +0300 ---------------------------------------------------------------------- pom.xml | 2 + sql-streaming-mqtt/README.md | 121 +++++++++++ sql-streaming-mqtt/pom.xml | 121 +++++++++++ .../src/main/assembly/assembly.xml | 44 ++++ .../mqtt/examples/JavaMQTTStreamWordCount.java | 83 ++++++++ ....apache.spark.sql.sources.DataSourceRegister | 1 + .../sql/streaming/mqtt/MQTTStreamSource.scala | 191 +++++++++++++++++ .../bahir/sql/streaming/mqtt/MessageStore.scala | 112 ++++++++++ .../mqtt/examples/MQTTStreamWordCount.scala | 73 +++++++ .../org/apache/bahir/utils/BahirUtils.scala | 48 +++++ .../scala/org/apache/bahir/utils/Logging.scala | 25 +++ .../src/test/resources/log4j.properties | 27 +++ .../streaming/mqtt/LocalMessageStoreSuite.scala | 73 +++++++ .../streaming/mqtt/MQTTStreamSourceSuite.scala | 208 +++++++++++++++++++ .../sql/streaming/mqtt/MQTTTestUtils.scala | 102 +++++++++ 15 files changed, 1231 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index f076801..644f4f0 100644 --- a/pom.xml +++ b/pom.xml @@ -77,6 +77,7 @@ <modules> 
<module>streaming-akka</module> <module>streaming-mqtt</module> + <module>sql-streaming-mqtt</module> <module>streaming-twitter</module> <module>streaming-zeromq</module> </modules> @@ -444,6 +445,7 @@ <exclude>.settings</exclude> <exclude>.classpath</exclude> <exclude>.project</exclude> + <exclude>**/META-INF/**</exclude> <exclude>**/dependency-reduced-pom.xml</exclude> <exclude>**/target/**</exclude> <exclude>**/README.md</exclude> http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/README.md ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/README.md b/sql-streaming-mqtt/README.md new file mode 100644 index 0000000..fa222b1 --- /dev/null +++ b/sql-streaming-mqtt/README.md @@ -0,0 +1,121 @@ +A library for reading data from MQTT servers using Spark SQL Streaming (also known as Structured Streaming). + +## Linking + +Using SBT: + +```scala +libraryDependencies += "org.apache.bahir" %% "spark-sql-streaming-mqtt" % "2.0.0" +``` + +Using Maven: + +```xml +<dependency> + <groupId>org.apache.bahir</groupId> + <artifactId>spark-sql-streaming-mqtt_2.11</artifactId> + <version>2.0.0</version> +</dependency> +``` + +This library can also be added to Spark jobs launched through `spark-shell` or `spark-submit` by using the `--packages` command line option. +For example, to include it when starting the spark shell: + +``` +$ bin/spark-shell --packages org.apache.bahir:spark-sql-streaming-mqtt_2.11:2.0.0 +``` + +Unlike using `--jars`, using `--packages` ensures that this library and its dependencies will be added to the classpath. +The `--packages` argument can also be used with `bin/spark-submit`. + +This library is compiled for Scala 2.11 only, and intends to support Spark 2.0 onwards.
+ +## Examples + +A SQL stream can be created from data received from an MQTT server as follows: + +```scala +sqlContext.readStream + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") + .option("topic", "mytopic") + .load("tcp://localhost:1883") + +``` + +## Enable recovering from failures + +Setting values for the options `localStorage` and `clientId` helps in recovering in case of a restart, by restoring the state where it left off before the shutdown. + +```scala +sqlContext.readStream + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") + .option("topic", "mytopic") + .option("localStorage", "/path/to/localdir") + .option("clientId", "some-client-id") + .load("tcp://localhost:1883") + +``` + +### Scala API + +An example using the Scala API to count words from an incoming message stream. + +```scala + // Create DataFrame representing the stream of input lines from connection to mqtt server + val lines = spark.readStream + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") + .option("topic", topic) + .load(brokerUrl).as[(String, Timestamp)] + + // Split the lines into words + val words = lines.map(_._1).flatMap(_.split(" ")) + + // Generate running word count + val wordCounts = words.groupBy("value").count() + + // Start running the query that prints the running counts to the console + val query = wordCounts.writeStream + .outputMode("complete") + .format("console") + .start() + + query.awaitTermination() + +``` +Please see `MQTTStreamWordCount.scala` for full example. + +### Java API + +An example using the Java API to count words from an incoming message stream. + +```java + + // Create DataFrame representing the stream of input lines from connection to mqtt server.
+ Dataset<String> lines = spark + .readStream() + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") + .option("topic", topic) + .load(brokerUrl).select("value").as(Encoders.STRING()); + + // Split the lines into words + Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() { + @Override + public Iterator<String> call(String x) { + return Arrays.asList(x.split(" ")).iterator(); + } + }, Encoders.STRING()); + + // Generate running word count + Dataset<Row> wordCounts = words.groupBy("value").count(); + + // Start running the query that prints the running counts to the console + StreamingQuery query = wordCounts.writeStream() + .outputMode("complete") + .format("console") + .start(); + + query.awaitTermination(); +``` + +Please see `JavaMQTTStreamWordCount.java` for full example. + http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/pom.xml ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/pom.xml b/sql-streaming-mqtt/pom.xml new file mode 100644 index 0000000..9d0d188 --- /dev/null +++ b/sql-streaming-mqtt/pom.xml @@ -0,0 +1,121 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.bahir</groupId> + <artifactId>bahir-parent_2.11</artifactId> + <version>2.0.0-SNAPSHOT</version> + <relativePath>../pom.xml</relativePath> + </parent> + + <groupId>org.apache.bahir</groupId> + <artifactId>spark-sql-streaming-mqtt_2.11</artifactId> + <properties> + <sbt.project.name>sql-streaming-mqtt</sbt.project.name> + </properties> + <packaging>jar</packaging> + <name>Apache Bahir - Spark SQL Streaming MQTT</name> + <url>http://bahir.apache.org/</url> + + <dependencies> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-tags_${scala.binary.version}</artifactId> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.eclipse.paho</groupId> + <artifactId>org.eclipse.paho.client.mqttv3</artifactId> + <version>1.1.0</version> + </dependency> + <dependency> + <groupId>org.scalacheck</groupId> + <artifactId>scalacheck_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.activemq</groupId> + <artifactId>activemq-all</artifactId> + 
<version>5.13.3</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.activemq</groupId> + <artifactId>activemq-mqtt</artifactId> + <version>5.13.3</version> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-source-plugin</artifactId> + </plugin> + + <!-- Assemble a jar with test dependencies for Python tests --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-assembly-plugin</artifactId> + <executions> + <execution> + <id>test-jar-with-dependencies</id> + <phase>package</phase> + <goals> + <goal>single</goal> + </goals> + <configuration> + <!-- Make sure the file path is same as the sbt build --> + <finalName>spark-streaming-mqtt-test-${project.version}</finalName> + <outputDirectory>${project.build.directory}/scala-${scala.binary.version}/</outputDirectory> + <appendAssemblyId>false</appendAssemblyId> + <!-- Don't publish it since it's only for Python tests --> + <attach>false</attach> + <descriptors> + <descriptor>src/main/assembly/assembly.xml</descriptor> + </descriptors> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/assembly/assembly.xml ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/assembly/assembly.xml b/sql-streaming-mqtt/src/main/assembly/assembly.xml new file mode 100644 index 0000000..c110b01 --- /dev/null +++ b/sql-streaming-mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. 
See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> +<assembly> + <id>test-jar-with-dependencies</id> + <formats> + <format>jar</format> + </formats> + <includeBaseDirectory>false</includeBaseDirectory> + + <fileSets> + <fileSet> + <directory>${project.build.directory}/scala-${scala.binary.version}/test-classes</directory> + <outputDirectory></outputDirectory> + </fileSet> + </fileSets> + + <dependencySets> + <dependencySet> + <useTransitiveDependencies>true</useTransitiveDependencies> + <scope>test</scope> + <unpack>true</unpack> + <excludes> + <exclude>org.apache.hadoop:*:jar</exclude> + <exclude>org.apache.zookeeper:*:jar</exclude> + <exclude>org.apache.avro:*:jar</exclude> + </excludes> + </dependencySet> + </dependencySets> + +</assembly> http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/java/org/apache/bahir/sql/streaming/mqtt/examples/JavaMQTTStreamWordCount.java ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/java/org/apache/bahir/sql/streaming/mqtt/examples/JavaMQTTStreamWordCount.java b/sql-streaming-mqtt/src/main/java/org/apache/bahir/sql/streaming/mqtt/examples/JavaMQTTStreamWordCount.java new file mode 100644 index 0000000..6d8935c --- /dev/null +++ 
b/sql-streaming-mqtt/src/main/java/org/apache/bahir/sql/streaming/mqtt/examples/JavaMQTTStreamWordCount.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.bahir.sql.streaming.mqtt.examples; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.util.Arrays; +import java.util.Iterator; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from MQTT Server. + * + * Usage: JavaMQTTStreamWordCount <brokerUrl> <topic> + * <brokerUrl> and <topic> describe the MQTT server that Structured Streaming + * would connect to receive data. + * + * To run this on your local machine, a MQTT Server should be up and running. 
+ * + */ +public final class JavaMQTTStreamWordCount { + + /** + * Runs the word count against the broker URL given in args[0], subscribed to the topic given in + * args[1], printing running counts to the console until the query terminates. + */ + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaMQTTStreamWordCount <brokerUrl> <topic>"); + System.exit(1); + } + + String brokerUrl = args[0]; + String topic = args[1]; + + SparkSession spark = SparkSession + .builder() + .appName("JavaMQTTStreamWordCount") + // NOTE(review): hard-coded local master. A master set in code likely takes precedence over + // spark-submit's --master flag, so remove it before running on a cluster (confirm). + .master("local[4]") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from connection to mqtt server + // Only the value column (the message payload) is kept; the source also emits a timestamp column. + Dataset<String> lines = spark + .readStream() + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") + .option("topic", topic) + .load(brokerUrl).select("value").as(Encoders.STRING()); + + // Split the lines into words (on single spaces) + Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() { + @Override + public Iterator<String> call(String x) { + return Arrays.asList(x.split(" ")).iterator(); + } + }, Encoders.STRING()); + + // Generate running word count + Dataset<Row> wordCounts = words.groupBy("value").count(); + + // Start running the query that prints the running counts to the console + StreamingQuery query = wordCounts.writeStream() + .outputMode("complete") + .format("console") + .start(); + + query.awaitTermination(); + } +} http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..1389e16 --- /dev/null +++ b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider \ No
newline at end of file http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala new file mode 100644 index 0000000..471886a --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ +package org.apache.bahir.sql.streaming.mqtt + +import java.nio.charset.StandardCharsets +import java.sql.Timestamp +import java.text.SimpleDateFormat +import java.util.Calendar +import java.util.concurrent.CountDownLatch + +import scala.collection.concurrent.TrieMap +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success, Try} + +import org.apache.bahir.utils.Logging +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence} + +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Source} +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + + +/** Constants shared by the MQTT structured-streaming source. */ +object MQTTStreamConstants { + + /** + * Legacy timestamp format, kept for source compatibility. `SimpleDateFormat` is not + * thread-safe, so this instance must not be shared across threads; the message parser in + * MQTTStreamSourceProvider no longer uses it. + */ + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + + /** Schema of every row produced by the source: the decoded payload and its arrival time. */ + val SCHEMA_DEFAULT = StructType(StructField("value", StringType) + :: StructField("timestamp", TimestampType) :: Nil) +} + +/** + * A Structured Streaming source that subscribes to one MQTT `topic` and exposes each received + * message as a row of MQTTStreamConstants.SCHEMA_DEFAULT. Offsets are consecutive message ids. + * Consumed messages are moved into a LocalMessageStore backed by the same `persistence` the MQTT + * client uses, which is what allows resuming from the last stored offset. + */ +class MQTTTextStreamSource(brokerUrl: String, persistence: MqttClientPersistence, + topic: String, clientId: String, messageParser: Array[Byte] => (String, Timestamp), + sqlContext: SQLContext) extends Source with Logging { + + override def schema: StructType = MQTTStreamConstants.SCHEMA_DEFAULT + + private val store = new LocalMessageStore(persistence, sqlContext.sparkContext.getConf) + + private val messages = new TrieMap[Int, (String, Timestamp)] + + private val initLock = new CountDownLatch(1) + + // Latest message id. It is written from the MQTT callback and read by getOffset, presumably + // on different threads. The callback's `synchronized` does not lock this source, so + // @volatile is what makes the new value (and the message put before it) visible. + @volatile private var offset = 0 + + private var client: MqttClient = _ + + private def fetchLastProcessedOffset(): Int = { + Try(store.maxProcessedOffset) match { + case Success(x) => + log.info(s"Recovering from last stored offset $x") + x + case Failure(e) => + log.debug("No stored offset could be recovered, starting from 0.", e) + 0 + } + } + + // Runs during construction: connects, subscribes and restores the offset before releasing + // initLock, which incoming messages wait on. + initialize() + private def initialize(): Unit = { + + client = new MqttClient(brokerUrl, clientId, persistence) + val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions() +
mqttConnectOptions.setAutomaticReconnect(true) + // This is required to support recovery. TODO: configurable ? + mqttConnectOptions.setCleanSession(false) + + val callback = new MqttCallbackExtended() { + + // Note: `synchronized` here locks this callback object, not the enclosing source. + override def messageArrived(topic_ : String, message: MqttMessage): Unit = synchronized { + initLock.await() // Wait for initialization to complete. + val temp = offset + 1 + // Store the message before advancing `offset`, so a reader that sees the new offset + // can also find the message. + messages.put(temp, messageParser(message.getPayload)) + offset = temp + log.trace(s"Message arrived, $topic_ $message") + } + + override def deliveryComplete(token: IMqttDeliveryToken): Unit = { + } + + override def connectionLost(cause: Throwable): Unit = { + log.warn("Connection to mqtt server lost.", cause) + } + + override def connectComplete(reconnect: Boolean, serverURI: String): Unit = { + log.info(s"Connect complete $serverURI. Is it a reconnect?: $reconnect") + } + } + client.setCallback(callback) + client.connect(mqttConnectOptions) + client.subscribe(topic) + // It is not possible to initialize offset without `client.connect` + offset = fetchLastProcessedOffset() + initLock.countDown() // Release. + } + + /** Stop this source and free any resources it has allocated. */ + override def stop(): Unit = { + client.disconnect() + persistence.close() + client.close() + } + + /** Returns the maximum available offset for this source. */ + override def getOffset: Option[Offset] = { + if (offset == 0) { + None + } else { + Some(LongOffset(offset)) + } + } + + /** + * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None` then + * the batch should begin with the first available record. This method must always return the + * same data for a particular `start` and `end` pair.
+ */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { + val startIndex = start.getOrElse(LongOffset(0L)).asInstanceOf[LongOffset].offset.toInt + val endIndex = end.asInstanceOf[LongOffset].offset.toInt + val data: ArrayBuffer[(String, Timestamp)] = ArrayBuffer.empty + // Move consumed messages to persistent store. Messages no longer held in memory + // (e.g. after a restart) are read back from the store instead. + (startIndex + 1 to endIndex).foreach { id => + val element: (String, Timestamp) = messages.getOrElse(id, store.retrieve(id)) + data += element + store.store(id, element) + messages.remove(id, element) + } + // Guarded so the whole batch is only rendered to a string when trace logging is enabled. + if (log.isTraceEnabled) { + log.trace(s"Get Batch invoked, ${data.mkString}") + } + import sqlContext.implicits._ + data.toDF("value", "timestamp") + } + +} + +/** + * Provider for the MQTT source, registered under the short name "mqtt". Options: `brokerUrl` + * (or the path passed to `load`), `topic` (required), `clientId` (random if unset, which rules + * out recovery), `persistence` ("memory" selects in-memory persistence, anything else file + * persistence) and `localStorage` (directory used by file persistence). + */ +class MQTTStreamSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { + + override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType], + providerName: String, parameters: Map[String, String]): (String, StructType) = { + ("mqtt", MQTTStreamConstants.SCHEMA_DEFAULT) + } + + override def createSource(sqlContext: SQLContext, metadataPath: String, + schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { + + def e(s: String) = new IllegalArgumentException(s) + + val brokerUrl: String = parameters.getOrElse("brokerUrl", parameters.getOrElse("path", + throw e("Please provide a `brokerUrl` by specifying path or .options(\"brokerUrl\",...)"))) + + + val persistence: MqttClientPersistence = parameters.get("persistence") match { + case Some("memory") => new MemoryPersistence() + case _ => val localStorage: Option[String] = parameters.get("localStorage") + localStorage match { + case Some(x) => new MqttDefaultFilePersistence(x) + case None => new MqttDefaultFilePersistence() + } + } + + // Called for every incoming message. Payloads are decoded as UTF-8 rather than with the + // platform default charset. The arrival time is truncated to whole seconds (the precision of + // DATE_FORMAT) without going through the shared, non-thread-safe SimpleDateFormat. + val messageParserWithTimeStamp = (x: Array[Byte]) => { + val nowMillis = Calendar.getInstance().getTimeInMillis + (new String(x, StandardCharsets.UTF_8), new Timestamp(nowMillis - nowMillis % 1000)) + } + + // if default is subscribe everything, it leads to getting lot
unwanted system messages. + val topic: String = parameters.getOrElse("topic", + throw e("Please specify a topic, by .options(\"topic\",...)")) + + val clientId: String = parameters.getOrElse("clientId", { + log.warn("If `clientId` is not set, a random value is picked up." + + "\nRecovering from failure is not supported in such a case.") + MqttClient.generateClientId()}) + + new MQTTTextStreamSource(brokerUrl, persistence, topic, clientId, + messageParserWithTimeStamp, sqlContext) + } + + override def shortName(): String = "mqtt" +} http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala new file mode 100644 index 0000000..e8e0f7d --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
*/ + + +package org.apache.bahir.sql.streaming.mqtt + +import java.nio.ByteBuffer +import java.util + +import scala.reflect.ClassTag +import scala.util.Try + +import org.apache.bahir.utils.Logging +import org.eclipse.paho.client.mqttv3.{MqttClientPersistence, MqttPersistable, MqttPersistenceException} + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerInstance} + + +/** A message store for MQTT stream source for SQL Streaming. */ +trait MessageStore { + + /** Store a single id and corresponding serialized message; returns whether it was persisted. */ + def store[T: ClassTag](id: Int, message: T): Boolean + + /** Retrieve messages corresponding to certain offset range */ + def retrieve[T: ClassTag](start: Int, end: Int): Seq[T] + + /** Retrieve message corresponding to a given id. */ + def retrieve[T: ClassTag](id: Int): T + + /** Highest offset we have stored, or 0 when nothing has been stored yet. */ + def maxProcessedOffset: Int + +} + +/** + * Adapts raw bytes to paho's `MqttPersistable`: the whole message is stored as the "header", + * with an empty payload. + */ +private[mqtt] class MqttPersistableData(bytes: Array[Byte]) extends MqttPersistable { + + override def getHeaderLength: Int = bytes.length + + override def getHeaderOffset: Int = 0 + + override def getPayloadOffset: Int = 0 + + override def getPayloadBytes: Array[Byte] = null + + override def getHeaderBytes: Array[Byte] = bytes + + override def getPayloadLength: Int = 0 +} + +/** + * A message store to persist messages received. This is not intended to be thread safe. + * Messages are serialized and written through the given `MqttClientPersistence`, keyed by their + * id (e.g. `MqttDefaultFilePersistence` stores them on disk locally on the client).
+ */ +private[mqtt] class LocalMessageStore(val persistentStore: MqttClientPersistence, + val serializer: Serializer) extends MessageStore with Logging { + + val classLoader = Thread.currentThread.getContextClassLoader + + def this(persistentStore: MqttClientPersistence, conf: SparkConf) = + this(persistentStore, new JavaSerializer(conf)) + + val serializerInstance: SerializerInstance = serializer.newInstance() + + private def get(id: Int) = { + persistentStore.get(id.toString).getHeaderBytes + } + + import scala.collection.JavaConverters._ + + /** + * Highest message id stored so far, or 0 if none. Keys that are not integers are skipped: the + * persistence may be shared with the MQTT client, which presumably keeps its own entries there. + */ + def maxProcessedOffset: Int = { + val keys: util.Enumeration[_] = persistentStore.keys() + keys.asScala.flatMap(x => Try(x.toString.toInt).toOption).foldLeft(0)(_ max _) + } + + /** Store a single id and corresponding serialized message; returns false if persisting fails. */ + override def store[T: ClassTag](id: Int, message: T): Boolean = { + val buffer: ByteBuffer = serializerInstance.serialize(message) + // Copy only the readable bytes. `array()` returns the whole backing array, which may be + // larger than the serialized data and ignores the buffer's position and limit. + val bytes: Array[Byte] = new Array[Byte](buffer.remaining()) + buffer.get(bytes) + try { + persistentStore.put(id.toString, new MqttPersistableData(bytes)) + true + } catch { + case e: MqttPersistenceException => log.warn(s"Failed to store message Id: $id", e) + false + } + } + + /** Retrieve the messages whose ids lie in the half-open range [`start`, `end`). */ + override def retrieve[T: ClassTag](start: Int, end: Int): Seq[T] = { + (start until end).map(x => retrieve(x)) + } + + /** Retrieve message corresponding to a given id.
*/ + override def retrieve[T: ClassTag](id: Int): T = { + serializerInstance.deserialize(ByteBuffer.wrap(get(id)), classLoader) + } + +} http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/examples/MQTTStreamWordCount.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/examples/MQTTStreamWordCount.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/examples/MQTTStreamWordCount.scala new file mode 100644 index 0000000..c792858 --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/examples/MQTTStreamWordCount.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.bahir.sql.streaming.mqtt.examples + +import java.sql.Timestamp + +import org.apache.spark.sql.SparkSession + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from MQTT Server. + * + * Usage: MQTTStreamWordCount <brokerUrl> <topic> + * <brokerUrl> and <topic> describe the MQTT server that Structured Streaming + * would connect to receive data.
+ * + * To run this on your local machine, an MQTT server should be up and running. + * + */ +object MQTTStreamWordCount { + + /** Entry point: `args(0)` is the MQTT broker URL and `args(1)` the topic to subscribe to. */ + def main(args: Array[String]): Unit = { + if (args.length < 2) { + // scalastyle:off println + System.err.println("Usage: MQTTStreamWordCount <brokerUrl> <topic>") + // scalastyle:on println + System.exit(1) + } + + val brokerUrl = args(0) + val topic = args(1) + + val spark = SparkSession + .builder + .appName("MQTTStreamWordCount") + // NOTE(review): hard-coded local master. A master set in code likely takes precedence over + // spark-submit's --master flag, so remove it before running on a cluster (confirm). + .master("local[4]") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from connection to mqtt server + val lines = spark.readStream + .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider") + .option("topic", topic) + .load(brokerUrl).as[(String, Timestamp)] + + // Split the lines into words (the payload is the first tuple element) + val words = lines.map(_._1).flatMap(_.split(" ")) + + // Generate running word count + val wordCounts = words.groupBy("value").count() + + // Start running the query that prints the running counts to the console + val query = wordCounts.writeStream + .outputMode("complete") + .format("console") + .start() + + query.awaitTermination() + } +} + http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/BahirUtils.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/BahirUtils.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/BahirUtils.scala new file mode 100644 index 0000000..3d27b06 --- /dev/null +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/BahirUtils.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.utils
+
+import java.io.{File, IOException}
+import java.nio.file.{Files, FileVisitResult, Path, SimpleFileVisitor}
+import java.nio.file.attribute.BasicFileAttributes
+
+import scala.util.control.NonFatal
+
+object BahirUtils extends Logging {
+
+  /**
+   * Deletes `dir` and everything beneath it, best-effort: an entry that cannot be deleted is
+   * logged at WARN level and the walk continues with the next entry.
+   *
+   * If `dir` itself cannot be visited (for example because it does not exist), the
+   * `IOException` is propagated, because `SimpleFileVisitor.visitFileFailed` rethrows it.
+   *
+   * @param dir the directory tree to delete.
+   * @return the starting path of the walk, as returned by `Files.walkFileTree`.
+   */
+  def recursiveDeleteDir(dir: File): Path = {
+    Files.walkFileTree(dir.toPath, new SimpleFileVisitor[Path]() {
+      override def visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult = {
+        deleteQuietly(file)
+        FileVisitResult.CONTINUE
+      }
+
+      // Invoked after all children of `dir` were visited, so a successful walk has already
+      // emptied it; if a child could not be deleted this delete fails too and is logged.
+      override def postVisitDirectory(dir: Path, exc: IOException): FileVisitResult = {
+        deleteQuietly(dir)
+        FileVisitResult.CONTINUE
+      }
+    })
+  }
+
+  /** Deletes a single path, logging (instead of propagating) any non-fatal failure. */
+  private def deleteQuietly(path: Path): Unit = {
+    try {
+      Files.delete(path)
+    } catch {
+      case NonFatal(e) => log.warn(s"Failed to delete $path", e)
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/Logging.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/Logging.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/Logging.scala
new file mode 100644
index 0000000..cbe97e9
--- /dev/null
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/utils/Logging.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license
agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.utils
+
+import org.slf4j.{Logger, LoggerFactory}
+
+/**
+ * Provides an SLF4J logger named after the runtime class of the implementor.
+ *
+ * The trailing `$` of a Scala `object`'s class name is stripped, so an object logs under
+ * its plain name.
+ */
+trait Logging {
+  final val log: Logger = LoggerFactory.getLogger(this.getClass.getName.stripSuffix("$"))
+}
http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/test/resources/log4j.properties
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/resources/log4j.properties b/sql-streaming-mqtt/src/test/resources/log4j.properties
new file mode 100644
index 0000000..3706a6e
--- /dev/null
+++ b/sql-streaming-mqtt/src/test/resources/log4j.properties
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark_project.jetty=WARN http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala new file mode 100644 index 0000000..44da041 --- /dev/null +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.mqtt
+
+import java.io.File
+
+import org.apache.bahir.utils.BahirUtils
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.JavaSerializer
+
+/** Tests [[LocalMessageStore]] on top of a Paho file persistence in a temporary directory. */
+class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter {
+
+  private val testData = Seq(1, 2, 3, 4, 5, 6)
+  private val javaSerializer: JavaSerializer = new JavaSerializer(new SparkConf())
+
+  private val serializerInstance = javaSerializer.newInstance()
+  private val tempDir: File = new File(System.getProperty("java.io.tmpdir"), "mqtt-test2")
+  private val persistence: MqttDefaultFilePersistence =
+    new MqttDefaultFilePersistence(tempDir.getAbsolutePath)
+
+  private val store = new LocalMessageStore(persistence, javaSerializer)
+
+  before {
+    tempDir.mkdirs()
+    tempDir.deleteOnExit()
+    persistence.open("temp", "tcp://dummy-url:0000")
+  }
+
+  after {
+    try {
+      persistence.clear()
+      persistence.close()
+    } finally {
+      // Remove the directory even if clearing/closing the persistence failed, so the next
+      // test does not see stale entries.
+      BahirUtils.recursiveDeleteDir(tempDir)
+    }
+  }
+
+  test("serialize and deserialize") {
+    val serialized = serializerInstance.serialize(testData)
+    // Explicit type argument: without it T would be inferred as Nothing.
+    val deserialized: Seq[Int] = serializerInstance.deserialize[Seq[Int]](serialized)
+    assert(testData === deserialized)
+  }
+
+  test("Store and retrieve") {
+    store.store(1, testData)
+    val result: Seq[Int] = store.retrieve(1)
+    assert(testData === result)
+  }
+
+  test("Max offset stored") {
+    store.store(1, testData)
+    store.store(10, testData)
+    val offset: Int = store.maxProcessedOffset
+    assert(offset === 10)
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
new file mode 100644
index 0000000..f6f5ff6
--- /dev/null
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.mqtt
+
+import java.io.File
+import java.sql.Timestamp
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.concurrent.Future
+
+import org.apache.bahir.utils.BahirUtils
+import org.eclipse.paho.client.mqttv3.MqttException
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.execution.streaming.LongOffset
+
+/**
+ * Base suite: starts an embedded MQTT broker before each test and afterwards deletes the
+ * shared temporary directory, which holds broker data, the source's local storage, the
+ * checkpoint location and the parquet output.
+ */
+class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter {
+
+  protected var mqttTestUtils: MQTTTestUtils = _
+  protected val tempDir: File = new File(System.getProperty("java.io.tmpdir"), "mqtt-test")
+
+  before {
+    // Create the directory first: the broker is configured to use it as its data directory.
+    tempDir.mkdirs()
+    mqttTestUtils = new MQTTTestUtils(tempDir)
+    mqttTestUtils.setup()
+  }
+
+  after {
+    try {
+      mqttTestUtils.teardown()
+    } finally {
+      BahirUtils.recursiveDeleteDir(tempDir)
+    }
+  }
+
+  protected val tmpDir: String = tempDir.getAbsolutePath
+
+  /**
+   * Creates a streaming DataFrame that reads topic "test" from the embedded broker using
+   * the fixed client id "clientId".
+   *
+   * @param dir local storage directory handed to the MQTT source.
+   * @return the SQLContext that was used and the streaming DataFrame.
+   */
+  protected def createStreamingDataframe(dir: String = tmpDir): (SQLContext, DataFrame) = {
+
+    val sqlContext: SQLContext = new SQLContext(sc)
+
+    sqlContext.setConf("spark.sql.streaming.checkpointLocation", tmpDir)
+
+    val dataFrame: DataFrame =
+      sqlContext.readStream.format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
+        .option("topic", "test").option("localStorage", dir).option("clientId", "clientId")
+        .load("tcp://" + mqttTestUtils.brokerUri)
+    (sqlContext, dataFrame)
+  }
+
+}
+
+class BasicMQTTSourceSuite extends MQTTStreamSourceSuite {
+
+  /**
+   * Streams `dataFrame` into a parquet sink under `tmpDir`. The query runs for up to
+   * `waitDuration` ms and is then stopped, so it neither outlives the test nor keeps writing
+   * while the results are read back.
+   *
+   * @return `true` if the query terminated on its own within `waitDuration`.
+   */
+  private def writeStreamResults(sqlContext: SQLContext,
+      dataFrame: DataFrame, waitDuration: Long): Boolean = {
+    import sqlContext.implicits._
+    val query = dataFrame.as[(String, Timestamp)].writeStream.format("parquet")
+      .start(s"$tmpDir/t.parquet")
+    try {
+      query.awaitTermination(waitDuration)
+    } finally {
+      query.stop()
+    }
+  }
+
+  /** Reads back, as a batch, the message texts written by `writeStreamResults`. */
+  private def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = {
+    import sqlContext.implicits._
+    sqlContext.read.schema(MQTTStreamConstants.SCHEMA_DEFAULT)
+      .parquet(s"$tmpDir/t.parquet").as[(String, Timestamp)].map(_._1)
+      .collectAsList().asScala
+  }
+
+  test("basic usage") {
+
+    val sendMessage = "MQTT is a message queue."
+
+    // Published before the stream subscribes; MQTTTestUtils publishes retained messages.
+    mqttTestUtils.publishData("test", sendMessage)
+
+    val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
+
+    writeStreamResults(sqlContext, dataFrame, 5000)
+
+    val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
+
+    assert(resultBuffer.size === 1)
+    assert(resultBuffer.head === sendMessage)
+  }
+
+  test("Send and receive 100 messages.") {
+
+    val sendMessage = "MQTT is a message queue."
+
+    import scala.concurrent.ExecutionContext.Implicits.global
+
+    val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
+
+    Future {
+      Thread.sleep(2000)
+      mqttTestUtils.publishData("test", sendMessage, 100)
+    }
+
+    writeStreamResults(sqlContext, dataFrame, 10000)
+
+    val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
+
+    assert(resultBuffer.size === 100)
+    assert(resultBuffer.head === sendMessage)
+  }
+
+  test("no server up") {
+    val provider = new MQTTStreamSourceProvider
+    val sqlContext: SQLContext = new SQLContext(sc)
+    val parameters = Map("brokerUrl" -> "tcp://localhost:1883", "topic" -> "test",
+      "localStorage" -> tmpDir)
+    intercept[MqttException] {
+      provider.createSource(sqlContext, "", None, "", parameters)
+    }
+  }
+
+  test("params not provided.") {
+    val provider = new MQTTStreamSourceProvider
+    val sqlContext: SQLContext = new SQLContext(sc)
+    val parameters = Map("brokerUrl" -> mqttTestUtils.brokerUri,
+      "localStorage" -> tmpDir)
+    intercept[IllegalArgumentException] {
+      provider.createSource(sqlContext, "", None, "", parameters)
+    }
+    intercept[IllegalArgumentException] {
+      provider.createSource(sqlContext, "", None, "", Map())
+    }
+  }
+
+  test("Recovering offset from the last processed offset.") {
+    val sendMessage = "MQTT is a message queue."
+
+    import scala.concurrent.ExecutionContext.Implicits.global
+
+    val (sqlContext: SQLContext, dataFrame: DataFrame) =
+      createStreamingDataframe()
+
+    Future {
+      Thread.sleep(2000)
+      mqttTestUtils.publishData("test", sendMessage, 100)
+    }
+
+    // writeStreamResults stops the first query before a new source is created below with
+    // the same client id and local storage.
+    writeStreamResults(sqlContext, dataFrame, 10000)
+    // On restarting the source with same params, it should begin from the offset - the
+    // previously running stream left off.
+    val provider = new MQTTStreamSourceProvider
+    val parameters = Map("brokerUrl" -> ("tcp://" + mqttTestUtils.brokerUri), "topic" -> "test",
+      "localStorage" -> tmpDir, "clientId" -> "clientId")
+    val offset: Long = provider.createSource(sqlContext, "", None, "", parameters)
+      .getOffset.get.asInstanceOf[LongOffset].offset
+    assert(offset === 100L)
+  }
+
+}
+
+class StressTestMQTTSource extends MQTTStreamSourceSuite {
+
+  // Run with -Xmx1024m
+  ignore("Send and receive messages of size 250MB.") {
+
+    val freeMemory: Long = Runtime.getRuntime.freeMemory()
+
+    log.info(s"Available memory before test run is ${freeMemory / (1024 * 1024)}MB.")
+
+    val messageSize = 500 * 1024 // characters per message (ASCII, so also bytes)
+    val noOfMsgs = (250 * 1024 * 1024) / messageSize // 512
+
+    // Fill one payload by cycling through 'A'..'Z'.
+    val messageBuilder = new StringBuilder(messageSize)
+    (0 until messageSize).foreach(i => messageBuilder.append(((i % 26) + 65).toChar))
+    val sendMessage = messageBuilder.toString() // each message is 500 KB
+
+    val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
+
+    import scala.concurrent.ExecutionContext.Implicits.global
+    Future {
+      Thread.sleep(2000)
+      mqttTestUtils.publishData("test", sendMessage, noOfMsgs)
+    }
+
+    import sqlContext.implicits._
+
+    val query = dataFrame.as[(String, Timestamp)].writeStream
+      .format("parquet")
+      .start(s"$tmpDir/t.parquet")
+    try {
+      query.awaitTermination(25000)
+    } finally {
+      query.stop()
+    }
+
+    val messageCount =
+      sqlContext.read.schema(MQTTStreamConstants.SCHEMA_DEFAULT)
+        .parquet(s"$tmpDir/t.parquet").as[(String, Timestamp)].map(_._1)
+        .count()
+    assert(messageCount == noOfMsgs)
+  }
+}
http://git-wip-us.apache.org/repos/asf/bahir/blob/c98dd0fe/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 0000000..bebeeef --- /dev/null +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.bahir.sql.streaming.mqtt
+
+import java.io.File
+import java.net.{ServerSocket, URI}
+
+import scala.util.control.NonFatal
+
+import org.apache.activemq.broker.{BrokerService, TransportConnector}
+import org.apache.bahir.utils.Logging
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+
+/**
+ * Test helper that runs an embedded ActiveMQ broker with an MQTT transport connector on
+ * `localhost` and publishes test messages to it.
+ *
+ * @param tempDir directory used as the broker's data directory and for the publishing
+ *                client's file persistence.
+ * @param port    broker port; `0` (the default) picks a currently free port.
+ */
+class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging {
+
+  private val persistenceDir = tempDir.getAbsolutePath
+  private val brokerHost = "localhost"
+  private val brokerPort: Int = if (port == 0) findFreePort() else port
+
+  private var broker: BrokerService = _
+  private var connector: TransportConnector = _
+
+  /** `host:port` of the broker, without a scheme. */
+  def brokerUri: String = {
+    s"$brokerHost:$brokerPort"
+  }
+
+  // Binds an ephemeral port to learn a free port number, then releases it again. Another
+  // process could grab the port before the broker binds it.
+  private def findFreePort(): Int = {
+    val socket = new ServerSocket(0)
+    try {
+      socket.getLocalPort
+    } finally {
+      socket.close()
+    }
+  }
+
+  /** Starts the broker with an MQTT connector listening on [[brokerUri]]. */
+  def setup(): Unit = {
+    broker = new BrokerService()
+    broker.setDataDirectoryFile(tempDir)
+    connector = new TransportConnector()
+    connector.setName("mqtt")
+    connector.setUri(new URI("mqtt://" + brokerUri))
+    broker.addConnector(connector)
+    broker.start()
+  }
+
+  /** Stops the broker and connector; safe to call when `setup()` was never run. */
+  def teardown(): Unit = {
+    try {
+      if (broker != null) {
+        broker.stop()
+      }
+    } finally {
+      broker = null
+      if (connector != null) {
+        connector.stop()
+        connector = null
+      }
+    }
+  }
+
+  /**
+   * Publishes `data` to `topic` `N` times (QoS 2, retained), pausing 20 ms before each send.
+   * Per-message failures are logged and the loop continues.
+   *
+   * @param topic MQTT topic to publish to.
+   * @param data  message payload, encoded with the platform default charset.
+   * @param N     number of times the message is published.
+   */
+  def publishData(topic: String, data: String, N: Int = 1): Unit = {
+    var client: MqttClient = null
+    try {
+      val persistence = new MqttDefaultFilePersistence(persistenceDir)
+      client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
+      client.connect()
+      if (client.isConnected) {
+        val msgTopic = client.getTopic(topic)
+        (1 to N).foreach { _ =>
+          try {
+            Thread.sleep(20)
+            val message = new MqttMessage(data.getBytes())
+            message.setQos(2)
+            message.setRetained(true)
+            msgTopic.publish(message)
+          } catch {
+            case e: MqttException =>
+              // wait for Spark sql streaming to consume something from the message queue
+              Thread.sleep(50)
+              log.warn("publish failed", e)
+            case NonFatal(e) => log.warn("publish failed", e)
+          }
+        }
+      }
+    } finally {
+      if (client != null) {
+        // disconnect() throws when the client is not connected (e.g. connect() failed), which
+        // would mask the exception that brought us here.
+        if (client.isConnected) {
+          client.disconnect()
+        }
+        client.close()
+      }
+    }
+  }
+
+}
