Repository: bahir
Updated Branches:
  refs/heads/master 3a211a74c -> b3902bac6


[BAHIR-164][BAHIR-165] Port MQTT SQL source to DataSource V2 API

Migrates the MQTT Spark Structured Streaming connector to the DataSource V2 API.
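For reference, a minimal read-side sketch against the migrated source (broker URL and topic below are placeholders; the payload column comes from the source's new default schema of id, topic, payload, timestamp described in the README changes below):

    // Read the MQTT payload as strings via the DataSource V2 provider.
    import spark.implicits._
    val lines = spark.readStream
      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
      .option("topic", "sensors")          // placeholder topic
      .load("tcp://localhost:1883")        // placeholder broker URL
      .selectExpr("CAST(payload AS STRING)").as[String]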

Closes #65


Project: http://git-wip-us.apache.org/repos/asf/bahir/repo
Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/b3902bac
Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/b3902bac
Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/b3902bac

Branch: refs/heads/master
Commit: b3902bac67edc2134bcc2c755fadc5c60c8ae01c
Parents: 3a211a7
Author: Prashant Sharma <prash...@in.ibm.com>
Authored: Fri Apr 27 12:39:35 2018 +0530
Committer: Luciano Resende <lrese...@apache.org>
Committed: Wed Nov 7 19:11:18 2018 -0800

----------------------------------------------------------------------
 pom.xml                                         |  18 +-
 .../streaming/akka/AkkaStreamSourceSuite.scala  |   2 +-
 sql-streaming-mqtt/README.md                    |  58 +++-
 .../streaming/mqtt/JavaMQTTStreamWordCount.java |   2 +-
 .../streaming/mqtt/MQTTStreamWordCount.scala    |   6 +-
 .../bahir/sql/streaming/mqtt/LongOffset.scala   |  54 ++++
 .../sql/streaming/mqtt/MQTTStreamSource.scala   | 284 ++++++++++++-------
 .../bahir/sql/streaming/mqtt/MessageStore.scala |  90 ++++--
 .../src/test/bin/test-BAHIR-83.sh               |  24 ++
 .../streaming/mqtt/LocalMessageStoreSuite.scala |   9 +-
 .../streaming/mqtt/MQTTStreamSourceSuite.scala  | 154 +++++-----
 .../sql/streaming/mqtt/MQTTTestUtils.scala      |  14 +-
 12 files changed, 462 insertions(+), 253 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 1edd641..8282346 100644
--- a/pom.xml
+++ b/pom.xml
@@ -77,7 +77,7 @@
   <modules>
     <module>sql-cloudant</module>
     <module>streaming-akka</module>
-    <module>sql-streaming-akka</module>
+    <!-- <module>sql-streaming-akka</module> Disabling akka sql module, until it is updated to run with datasource v2 API. -->
     <module>streaming-mqtt</module>
     <module>sql-streaming-mqtt</module>
     <module>streaming-twitter</module>
@@ -99,7 +99,7 @@
     <log4j.version>1.2.17</log4j.version>
 
     <!-- Spark version -->
-    <spark.version>2.2.2</spark.version>
+    <spark.version>2.3.0</spark.version>
 
     <!-- MQTT Client -->
     <mqtt.paho.client>1.1.0</mqtt.paho.client>
@@ -348,13 +348,13 @@
       <dependency>
         <groupId>org.scalatest</groupId>
         <artifactId>scalatest_${scala.binary.version}</artifactId>
-        <version>2.2.6</version>
+        <version>3.0.3</version>
         <scope>test</scope>
       </dependency>
       <dependency>
         <groupId>org.scalacheck</groupId>
         <artifactId>scalacheck_${scala.binary.version}</artifactId>
-        <version>1.12.5</version> <!-- 1.13.0 appears incompatible with scalatest 2.2.6 -->
+        <version>1.13.5</version>
         <scope>test</scope>
       </dependency>
 
@@ -407,7 +407,7 @@
         <plugin>
           <groupId>org.apache.maven.plugins</groupId>
           <artifactId>maven-enforcer-plugin</artifactId>
-          <version>1.4.1</version>
+          <version>3.0.0-M1</version>
           <executions>
             <execution>
               <id>enforce-versions</id>
@@ -433,6 +433,7 @@
                       -->
                       <exclude>org.jboss.netty</exclude>
                       <exclude>org.codehaus.groovy</exclude>
+                      <exclude>*:*_2.10</exclude>
                     </excludes>
                     <searchTransitive>true</searchTransitive>
                   </bannedDependencies>
@@ -482,7 +483,8 @@
         <plugin>
           <groupId>net.alchim31.maven</groupId>
           <artifactId>scala-maven-plugin</artifactId>
-          <version>3.3.1</version>
+          <!-- 3.3.1 won't work with zinc; fails to find javac from java.home -->
+          <version>3.2.2</version>
           <executions>
             <execution>
               <id>eclipse-add-source</id>
@@ -557,7 +559,7 @@
         <plugin>
           <groupId>org.apache.maven.plugins</groupId>
           <artifactId>maven-surefire-plugin</artifactId>
-          <version>2.19.1</version>
+          <version>2.20.1</version>
           <!-- Note config is repeated in scalatest config -->
           <configuration>
             <includes>
@@ -567,7 +569,7 @@
               <include>**/*Suite.java</include>
             </includes>
             <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
-            <argLine>-Xmx3g -Xss4096k -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
+            <argLine>-ea -Xmx3g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
             <environmentVariables>
               <!--
                 Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala b/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
index 5e9b86e..cdf629b 100644
--- a/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
+++ b/sql-streaming-akka/src/test/scala/org/apache/bahir/sql/streaming/akka/AkkaStreamSourceSuite.scala
@@ -155,7 +155,7 @@ class StressTestAkkaSource extends AkkaStreamSourceSuite {
 
   // Run with -Xmx1024m
   // Default allowed payload size sent to an akka actor is 128000 bytes.
-  test("Send & Receive messages of size 128000 bytes.") {
+  ignore("Send & Receive messages of size 128000 bytes.") {
 
     val freeMemory: Long = Runtime.getRuntime.freeMemory()
 

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/README.md
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/README.md b/sql-streaming-mqtt/README.md
index 2cfbe0f..b7f0602 100644
--- a/sql-streaming-mqtt/README.md
+++ b/sql-streaming-mqtt/README.md
@@ -59,7 +59,9 @@ This source uses [Eclipse Paho Java Client](https://eclipse.org/paho/clients/jav
  * `connectionTimeout` Sets the connection timeout, a value of 0 is interpretted as wait until client connects. See `MqttConnectOptions.setConnectionTimeout` for more information.
  * `keepAlive` Same as `MqttConnectOptions.setKeepAliveInterval`.
  * `mqttVersion` Same as `MqttConnectOptions.setMqttVersion`.
-
+ * `maxInflight` Same as `MqttConnectOptions.setMaxInflight`
+ * `autoReconnect` Same as `MqttConnectOptions.setAutomaticReconnect`
+ 
 ### Scala API
 
 An example, for scala API to count words from incoming message stream. 
@@ -68,7 +70,7 @@ An example, for scala API to count words from incoming message stream.
     val lines = spark.readStream
       .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
       .option("topic", topic)
-      .load(brokerUrl).as[(String, Timestamp)]
+      .load(brokerUrl).selectExpr("CAST(payload AS STRING)").as[String]
 
     // Split the lines into words
     val words = lines.map(_._1).flatMap(_.split(" "))
@@ -95,7 +97,8 @@ An example, for Java API to count words from incoming message stream.
             .readStream()
             .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
             .option("topic", topic)
-            .load(brokerUrl).select("value").as(Encoders.STRING());
+            .load(brokerUrl)
+            .selectExpr("CAST(payload AS STRING)").as(Encoders.STRING());
 
     // Split the lines into words
     Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@@ -118,3 +121,52 @@ An example, for Java API to count words from incoming message stream.
 
 Please see `JavaMQTTStreamWordCount.java` for full example.
 
+## Best Practices
+
+1. Turn MQTT into a more reliable messaging service.
+
+> *MQTT is a machine-to-machine (M2M)/"Internet of Things" connectivity protocol. It was designed as an extremely lightweight publish/subscribe messaging transport.*
+
+The design of MQTT and the purpose it serves go well together, but applications often need stronger reliability. Since MQTT is not a distributed message queue, it does not offer the highest level of reliability features, so it is worth routing messages through a Kafka message queue to get the benefits of one. In fact, using a Kafka message queue opens up many possibilities, including a single Kafka topic subscribed to several MQTT sources, or even a single MQTT stream publishing to multiple Kafka topics. Kafka is a reliable and scalable message queue.
+
+2. Often the message payload is not in the default character encoding, or contains binary data that must be parsed with a particular parser. In such cases, the Spark MQTT payload should be processed using the external parser. For example:
+
+ * Scala API example:
+```scala
+    // Create DataFrame representing the stream of binary messages
+    val lines = spark.readStream
+      .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
+      .option("topic", topic)
+      .load(brokerUrl).select("payload").as[Array[Byte]].map(externalParser(_))
+```
+
+ * Java API example
+```java
+        // Create DataFrame representing the stream of binary messages
+        Dataset<byte[]> lines = spark
+                .readStream()
+                .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
+                .option("topic", topic)
+                .load(brokerUrl).selectExpr("CAST(payload AS BINARY)").as(Encoders.BINARY());
+
+        // Split the lines into words
+        Dataset<String> words = lines.map(new MapFunction<byte[], String>() {
+            @Override
+            public String call(byte[] bytes) throws Exception {
+                return new String(bytes); // Plug in external parser here.
+            }
+        }, Encoders.STRING()).flatMap(new FlatMapFunction<String, String>() {
+            @Override
+            public Iterator<String> call(String x) {
+                return Arrays.asList(x.split(" ")).iterator();
+            }
+        }, Encoders.STRING());
+
+```
+
+3. What is the solution for a situation where there are a large number of varied MQTT sources, each with different schema and throughput characteristics?
+
+Generally, one would create many streaming pipelines to solve this problem. That either requires a very sophisticated scheduling setup or wastes a lot of resources, since it is not known in advance which stream consumes more data.
+
+That general solution is both less optimal and more cumbersome to operate, and its many moving parts incur a high maintenance overhead. As an alternative, one can set up a single-topic Kafka-Spark stream, where the message from each of the varied streams carries a unique tag separating it from the other streams. At the processing end, the messages can then be distinguished from one another and given the right kind of decoding and processing. Similarly, while storing, each message can be distinguished from the others by its tag.
+

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java b/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
index 519d9a0..4e87c99 100644
--- a/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
+++ b/sql-streaming-mqtt/examples/src/main/java/org/apache/bahir/examples/sql/streaming/mqtt/JavaMQTTStreamWordCount.java
@@ -71,7 +71,7 @@ public final class JavaMQTTStreamWordCount {
                 .readStream()
                 .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
                 .option("topic", topic)
-                .load(brokerUrl).select("value").as(Encoders.STRING());
+                .load(brokerUrl).selectExpr("CAST(payload AS STRING)").as(Encoders.STRING());
 
         // Split the lines into words
         Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() {

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala b/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
index 237a8fa..ee7de22 100644
--- a/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
+++ b/sql-streaming-mqtt/examples/src/main/scala/org/apache/bahir/examples/sql/streaming/mqtt/MQTTStreamWordCount.scala
@@ -52,11 +52,11 @@ object MQTTStreamWordCount  {
     // Create DataFrame representing the stream of input lines from connection to mqtt server
     val lines = spark.readStream
       .format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
-      .option("topic", topic)
-      .load(brokerUrl).as[(String, Timestamp)]
+      .option("topic", topic).option("persistence", "memory")
+      .load(brokerUrl).selectExpr("CAST(payload AS STRING)").as[String]
 
     // Split the lines into words
-    val words = lines.map(_._1).flatMap(_.split(" "))
+    val words = lines.flatMap(_.split(" "))
 
     // Generate running word count
     val wordCounts = words.groupBy("value").count()

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala
new file mode 100644
index 0000000..345b576
--- /dev/null
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/LongOffset.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.bahir.sql.streaming.mqtt
+
+import org.apache.spark.sql.execution.streaming.Offset
+import org.apache.spark.sql.execution.streaming.SerializedOffset
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
+
+/**
+ * A simple offset for sources that produce a single linear stream of data.
+ */
+case class LongOffset(offset: Long) extends OffsetV2 {
+
+  override val json = offset.toString
+
+  def +(increment: Long): LongOffset = new LongOffset(offset + increment)
+  def -(decrement: Long): LongOffset = new LongOffset(offset - decrement)
+}
+
+object LongOffset {
+
+  /**
+   * LongOffset factory from serialized offset.
+   *
+   * @return new LongOffset
+   */
+  def apply(offset: SerializedOffset) : LongOffset = new LongOffset(offset.json.toLong)
+
+  /**
+   * Convert generic Offset to LongOffset if possible.
+   *
+   * @return converted LongOffset
+   */
+  def convert(offset: Offset): Option[LongOffset] = offset match {
+    case lo: LongOffset => Some(lo)
+    case so: SerializedOffset => Some(LongOffset(so))
+    case _ => None
+  }
+}

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
index 1739ff3..2f75ee2 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala
@@ -20,20 +20,23 @@ package org.apache.bahir.sql.streaming.mqtt
 import java.nio.charset.Charset
 import java.sql.Timestamp
 import java.text.SimpleDateFormat
-import java.util.Calendar
-import java.util.concurrent.CountDownLatch
+import java.util.{Calendar, Optional}
+import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.JavaConverters._
 import scala.collection.concurrent.TrieMap
-import scala.collection.mutable.ArrayBuffer
-import scala.util.{Failure, Success, Try}
+import scala.collection.immutable.IndexedSeq
+import scala.collection.mutable.ListBuffer
 
 import org.eclipse.paho.client.mqttv3._
 import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence}
 
-import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Source}
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
-import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
+import org.apache.spark.sql.types._
 
 import org.apache.bahir.utils.Logging
 
@@ -42,15 +45,38 @@ object MQTTStreamConstants {
 
   val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
 
-  val SCHEMA_DEFAULT = StructType(StructField("value", StringType)
-    :: StructField("timestamp", TimestampType) :: Nil)
+  val SCHEMA_DEFAULT = StructType(StructField("id", IntegerType) :: StructField("topic",
+    StringType):: StructField("payload", BinaryType) :: StructField("timestamp", TimestampType) ::
+    Nil)
 }
 
+class MQTTMessage(m: MqttMessage, val topic: String) extends Serializable {
+
+  // TODO: make it configurable.
+  val timestamp: Timestamp = Timestamp.valueOf(
+    MQTTStreamConstants.DATE_FORMAT.format(Calendar.getInstance().getTime))
+  val duplicate = m.isDuplicate
+  val retained = m.isRetained
+  val qos = m.getQos
+  val id: Int = m.getId
+
+  val payload: Array[Byte] = m.getPayload
+
+  override def toString(): String = {
+    s"""MQTTMessage.
+       |Topic: ${this.topic}
+       |MessageID: ${this.id}
+       |QoS: ${this.qos}
+       |Payload: ${this.payload}
+       |Payload as string: ${new String(this.payload, Charset.defaultCharset())}
+       |isRetained: ${this.retained}
+       |isDuplicate: ${this.duplicate}
+       |TimeStamp: ${this.timestamp}
+     """.stripMargin
+  }
+}
 /**
- * A Text based mqtt stream source, it interprets the payload of each incoming message by converting
- * the bytes to String using Charset.defaultCharset as charset. Each value is associated with a
- * timestamp of arrival of the message on the source. It can be used to operate a window on the
- * incoming stream.
+ * A mqtt stream source.
  *
  * @param brokerUrl url MqttClient connects to.
  * @param persistence an instance of MqttClientPersistence. By default it is used for storing
@@ -59,53 +85,49 @@ object MQTTStreamConstants {
  * @param topic topic MqttClient subscribes to.
  * @param clientId clientId, this client is assoicated with. Provide the same value to recover
  *                 a stopped client.
- * @param messageParser parsing logic for processing incoming messages from Mqtt Server.
- * @param sqlContext Spark provided, SqlContext.
  * @param mqttConnectOptions an instance of MqttConnectOptions for this Source.
  * @param qos the maximum quality of service to subscribe each topic at.Messages published at
  *            a lower quality of service will be received at the published QoS. Messages
  *            published at a higher quality of service will be received using the QoS specified
  *            on the subscribe.
  */
-class MQTTTextStreamSource(brokerUrl: String, persistence: MqttClientPersistence,
-    topic: String, clientId: String, messageParser: Array[Byte] => (String, Timestamp),
-    sqlContext: SQLContext, mqttConnectOptions: MqttConnectOptions, qos: Int)
-  extends Source with Logging {
+class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistence:
+    MqttClientPersistence, topic: String, clientId: String,
+    mqttConnectOptions: MqttConnectOptions, qos: Int)
+  extends MicroBatchReader with Logging {
+
+  private var startOffset: OffsetV2 = _
+  private var endOffset: OffsetV2 = _
 
-  override def schema: StructType = MQTTStreamConstants.SCHEMA_DEFAULT
+  /* Older than last N messages, will not be checked for redelivery. */
+  val backLog = options.getInt("autopruning.backlog", 500)
 
-  private val store = new LocalMessageStore(persistence, sqlContext.sparkContext.getConf)
+  private val store = new LocalMessageStore(persistence)
 
-  private val messages = new TrieMap[Int, (String, Timestamp)]
+  private val messages = new TrieMap[Long, MQTTMessage]
 
-  private val initLock = new CountDownLatch(1)
+  @GuardedBy("this")
+  private var currentOffset: LongOffset = LongOffset(-1L)
 
-  private var offset = 0
+  @GuardedBy("this")
+  private var lastOffsetCommitted: LongOffset = LongOffset(-1L)
 
   private var client: MqttClient = _
 
-  private def fetchLastProcessedOffset(): Int = {
-    Try(store.maxProcessedOffset) match {
-      case Success(x) =>
-        log.info(s"Recovering from last stored offset $x")
-        x
-      case Failure(e) => 0
-    }
-  }
+  private[mqtt] def getCurrentOffset = currentOffset
 
   initialize()
   private def initialize(): Unit = {
 
     client = new MqttClient(brokerUrl, clientId, persistence)
-
     val callback = new MqttCallbackExtended() {
 
       override def messageArrived(topic_ : String, message: MqttMessage): Unit = synchronized {
-        initLock.await() // Wait for initialization to complete.
-        val temp = offset + 1
-        messages.put(temp, messageParser(message.getPayload))
-        offset = temp
-        log.trace(s"Message arrived, $topic_ $message")
+        val mqttMessage = new MQTTMessage(message, topic_)
+        val offset = currentOffset.offset + 1L
+        messages.put(offset, mqttMessage)
+        currentOffset = LongOffset(offset)
+        log.trace(s"Message arrived, $topic_ $mqttMessage")
       }
 
       override def deliveryComplete(token: IMqttDeliveryToken): Unit = {
@@ -121,116 +143,162 @@ class MQTTTextStreamSource(brokerUrl: String, persistence: MqttClientPersistence
     }
     client.setCallback(callback)
     client.connect(mqttConnectOptions)
-    client.subscribe(topic, qos)
     // It is not possible to initialize offset without `client.connect`
-    offset = fetchLastProcessedOffset()
-    initLock.countDown() // Release.
+    client.subscribe(topic, qos)
   }
 
-  /** Stop this source and free any resources it has allocated. */
-  override def stop(): Unit = {
-    client.disconnect()
-    persistence.close()
-    client.close()
+  override def setOffsetRange(
+      start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = synchronized {
+    startOffset = start.orElse(LongOffset(-1L))
+    endOffset = end.orElse(currentOffset)
   }
 
-  /** Returns the maximum available offset for this source. */
-  override def getOffset: Option[Offset] = {
-    if (offset == 0) {
-      None
-    } else {
-      Some(LongOffset(offset))
-    }
+  override def getStartOffset(): OffsetV2 = {
+    Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set"))
+  }
+
+  override def getEndOffset(): OffsetV2 = {
+    Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set"))
   }
 
-  /**
-   * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None` then
-   * the batch should begin with the first available record. This method must always return the
-   * same data for a particular `start` and `end` pair.
-   */
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
-    val startIndex = start.getOrElse(LongOffset(0L)).asInstanceOf[LongOffset].offset.toInt
-    val endIndex = end.asInstanceOf[LongOffset].offset.toInt
-    val data: ArrayBuffer[(String, Timestamp)] = ArrayBuffer.empty
-    // Move consumed messages to persistent store.
-    (startIndex + 1 to endIndex).foreach { id =>
-      val element: (String, Timestamp) = messages.getOrElse(id, store.retrieve(id))
-      data += element
-      store.store(id, element)
-      messages.remove(id, element)
+  override def deserializeOffset(json: String): OffsetV2 = {
+    LongOffset(json.toLong)
+  }
+
+  override def readSchema(): StructType = {
+    MQTTStreamConstants.SCHEMA_DEFAULT
+  }
+
+  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+    val rawList: IndexedSeq[MQTTMessage] = synchronized {
+      val sliceStart = LongOffset.convert(startOffset).get.offset + 1
+      val sliceEnd = LongOffset.convert(endOffset).get.offset + 1
+      for (i <- sliceStart until sliceEnd) yield messages(i)
+    }
+    val spark = SparkSession.getActiveSession.get
+    val numPartitions = spark.sparkContext.defaultParallelism
+
+    val slices = Array.fill(numPartitions)(new ListBuffer[MQTTMessage])
+    rawList.zipWithIndex.foreach { case (r, idx) =>
+      slices(idx % numPartitions).append(r)
     }
-    log.trace(s"Get Batch invoked, ${data.mkString}")
-    import sqlContext.implicits._
-    data.toDF("value", "timestamp")
+
+    (0 until numPartitions).map { i =>
+      val slice = slices(i)
+      new DataReaderFactory[Row] {
+        override def createDataReader(): DataReader[Row] = new DataReader[Row] {
+          private var currentIdx = -1
+
+          override def next(): Boolean = {
+            currentIdx += 1
+            currentIdx < slice.size
+          }
+
+          override def get(): Row = {
+            Row(slice(currentIdx).id, slice(currentIdx).topic,
+              slice(currentIdx).payload, slice(currentIdx).timestamp)
+          }
+
+          override def close(): Unit = {}
+        }
+      }
+    }.toList.asJava
   }
 
-}
+  override def commit(end: OffsetV2): Unit = synchronized {
+    val newOffset = LongOffset.convert(end).getOrElse(
+      sys.error(s"MQTTStreamSource.commit() received an offset ($end) that did not " +
+        s"originate with an instance of this class")
+    )
 
-class MQTTStreamSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
+    val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
 
-  override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType],
-      providerName: String, parameters: Map[String, String]): (String, StructType) = {
-    ("mqtt", MQTTStreamConstants.SCHEMA_DEFAULT)
+    if (offsetDiff < 0) {
+      sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
+    }
+
+    (lastOffsetCommitted.offset until newOffset.offset).foreach { x =>
+      messages.remove(x + 1)
+    }
+    lastOffsetCommitted = newOffset
+  }
+
+  /** Stop this source. */
+  override def stop(): Unit = synchronized {
+    client.disconnect()
+    persistence.close()
+    client.close()
   }
 
-  override def createSource(sqlContext: SQLContext, metadataPath: String,
-      schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = {
+  override def toString: String = s"MQTTStreamSource[brokerUrl: $brokerUrl, topic: $topic" +
+    s" clientId: $clientId]"
+}
+
+class MQTTStreamSourceProvider extends DataSourceV2
+  with MicroBatchReadSupport with DataSourceRegister with Logging {
 
+  override def createMicroBatchReader(schema: Optional[StructType],
+      checkpointLocation: String, parameters: DataSourceOptions): MicroBatchReader = {
     def e(s: String) = new IllegalArgumentException(s)
+    if (schema.isPresent) {
+      throw e("The mqtt source does not support a user-specified schema.")
+    }
 
-    val brokerUrl: String = parameters.getOrElse("brokerUrl", parameters.getOrElse("path",
-      throw e("Please provide a `brokerUrl` by specifying path or .options(\"brokerUrl\",...)")))
+    val brokerUrl = parameters.get("brokerUrl").orElse(parameters.get("path").orElse(null))
 
+    if (brokerUrl == null) {
+      throw e("Please provide a broker url, with option(\"brokerUrl\", ...).")
+    }
 
-    val persistence: MqttClientPersistence = parameters.get("persistence") match {
-      case Some("memory") => new MemoryPersistence()
-      case _ => val localStorage: Option[String] = parameters.get("localStorage")
+    val persistence: MqttClientPersistence = parameters.get("persistence").orElse("") match {
+      case "memory" => new MemoryPersistence()
+      case _ => val localStorage: String = parameters.get("localStorage").orElse("")
         localStorage match {
-          case Some(x) => new MqttDefaultFilePersistence(x)
-          case None => new MqttDefaultFilePersistence()
+          case "" => new MqttDefaultFilePersistence()
+          case x => new MqttDefaultFilePersistence(x)
         }
     }
 
-    val messageParserWithTimeStamp = (x: Array[Byte]) =>
-      (new String(x, Charset.defaultCharset()), Timestamp.valueOf(
-      MQTTStreamConstants.DATE_FORMAT.format(Calendar.getInstance().getTime)))
-
     // if default is subscribe everything, it leads to getting lot unwanted system messages.
-    val topic: String = parameters.getOrElse("topic",
-      throw e("Please specify a topic, by .options(\"topic\",...)"))
+    val topic: String = parameters.get("topic").orElse(null)
+    if (topic == null) {
+      throw e("Please specify a topic, by .options(\"topic\",...)")
+    }
 
-    val clientId: String = parameters.getOrElse("clientId", {
+    val clientId: String = parameters.get("clientId").orElse {
       log.warn("If `clientId` is not set, a random value is picked up." +
-        "\nRecovering from failure is not supported in such a case.")
-      MqttClient.generateClientId()})
+        " Recovering from failure is not supported in such a case.")
+      MqttClient.generateClientId()}
+
+    val username: String = parameters.get("username").orElse(null)
+    val password: String = parameters.get("password").orElse(null)
 
-    val username: Option[String] = parameters.get("username")
-    val password: Option[String] = parameters.get("password")
-    val connectionTimeout: Int = parameters.getOrElse("connectionTimeout",
+    val connectionTimeout: Int = parameters.get("connectionTimeout").orElse(
       MqttConnectOptions.CONNECTION_TIMEOUT_DEFAULT.toString).toInt
-    val keepAlive: Int = parameters.getOrElse("keepAlive", MqttConnectOptions
+    val keepAlive: Int = parameters.get("keepAlive").orElse(MqttConnectOptions
       .KEEP_ALIVE_INTERVAL_DEFAULT.toString).toInt
-    val mqttVersion: Int = parameters.getOrElse("mqttVersion", MqttConnectOptions
+    val mqttVersion: Int = parameters.get("mqttVersion").orElse(MqttConnectOptions
       .MQTT_VERSION_DEFAULT.toString).toInt
-    val cleanSession: Boolean = parameters.getOrElse("cleanSession", "false").toBoolean
-    val qos: Int = parameters.getOrElse("QoS", "1").toInt
-
+    val cleanSession: Boolean = parameters.get("cleanSession").orElse("true").toBoolean
+    val qos: Int = parameters.get("QoS").orElse("1").toInt
+    val autoReconnect: Boolean = parameters.get("autoReconnect").orElse("false").toBoolean
+    val maxInflight: Int = parameters.get("maxInflight").orElse("60").toInt
     val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions()
-    mqttConnectOptions.setAutomaticReconnect(true)
+    mqttConnectOptions.setAutomaticReconnect(autoReconnect)
     mqttConnectOptions.setCleanSession(cleanSession)
     mqttConnectOptions.setConnectionTimeout(connectionTimeout)
     mqttConnectOptions.setKeepAliveInterval(keepAlive)
     mqttConnectOptions.setMqttVersion(mqttVersion)
+    mqttConnectOptions.setMaxInflight(maxInflight)
     (username, password) match {
-      case (Some(u: String), Some(p: String)) =>
+      case (u: String, p: String) if u != null && p != null =>
         mqttConnectOptions.setUserName(u)
         mqttConnectOptions.setPassword(p.toCharArray)
       case _ =>
     }
 
-    new MQTTTextStreamSource(brokerUrl, persistence, topic, clientId,
-      messageParserWithTimeStamp, sqlContext, mqttConnectOptions, qos)
+    new  MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId,
+      mqttConnectOptions, qos)
   }
-
   override def shortName(): String = "mqtt"
 }

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
index 84fd8c4..d7d2657 100644
--- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
+++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala
@@ -18,15 +18,11 @@
 
 package org.apache.bahir.sql.streaming.mqtt
 
-import java.nio.ByteBuffer
+import java.io._
 import java.util
 
-import scala.reflect.ClassTag
-
 import org.eclipse.paho.client.mqttv3.{MqttClientPersistence, MqttPersistable, MqttPersistenceException}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerInstance}
+import scala.util.Try
 
 import org.apache.bahir.utils.Logging
 
@@ -35,16 +31,13 @@ import org.apache.bahir.utils.Logging
 trait MessageStore {
 
   /** Store a single id and corresponding serialized message */
-  def store[T: ClassTag](id: Int, message: T): Boolean
-
-  /** Retrieve messages corresponding to certain offset range */
-  def retrieve[T: ClassTag](start: Int, end: Int): Seq[T]
+  def store[T](id: Long, message: T): Boolean
 
   /** Retrieve message corresponding to a given id. */
-  def retrieve[T: ClassTag](id: Int): T
+  def retrieve[T](id: Long): T
 
   /** Highest offset we have stored */
-  def maxProcessedOffset: Int
+  def maxProcessedOffset: Long
 
 }
 
@@ -63,6 +56,52 @@ private[mqtt] class MqttPersistableData(bytes: Array[Byte]) extends MqttPersista
   override def getPayloadLength: Int = 0
 }
 
+trait Serializer {
+
+  def deserialize[T](x: Array[Byte]): T
+
+  def serialize[T](x: T): Array[Byte]
+}
+
+class JavaSerializer extends Serializer with Logging {
+
+  override def deserialize[T](x: Array[Byte]): T = {
+    val bis = new ByteArrayInputStream(x)
+    val in = new ObjectInputStream(bis)
+    val obj = if (in != null) {
+      val o = in.readObject()
+      Try(in.close()).recover { case t: Throwable => log.warn("failed to close stream", t) }
+      o
+    } else {
+      null
+    }
+    obj.asInstanceOf[T]
+  }
+
+  override def serialize[T](x: T): Array[Byte] = {
+    val bos = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(bos)
+    out.writeObject(x)
+    out.flush()
+    if (bos != null) {
+      val bytes: Array[Byte] = bos.toByteArray
+      Try(bos.close()).recover { case t: Throwable => log.warn("failed to close stream", t) }
+      bytes
+    } else {
+      null
+    }
+  }
+}
+
+object JavaSerializer {
+
+  private lazy val instance = new JavaSerializer()
+
+  def getInstance(): JavaSerializer = instance
+
+}
+
+
 /**
  * A message store to persist messages received. This is not intended to be thread safe.
  * It uses `MqttDefaultFilePersistence` for storing messages on disk locally on the client.
@@ -70,44 +109,35 @@ private[mqtt] class MqttPersistableData(bytes: Array[Byte]) extends MqttPersista
 private[mqtt] class LocalMessageStore(val persistentStore: MqttClientPersistence,
     val serializer: Serializer) extends MessageStore with Logging {
 
-  val classLoader = Thread.currentThread.getContextClassLoader
-
-  def this(persistentStore: MqttClientPersistence, conf: SparkConf) =
-    this(persistentStore, new JavaSerializer(conf))
+  def this(persistentStore: MqttClientPersistence) =
+    this(persistentStore, JavaSerializer.getInstance())
 
-  val serializerInstance: SerializerInstance = serializer.newInstance()
-
-  private def get(id: Int) = {
+  private def get(id: Long) = {
     persistentStore.get(id.toString).getHeaderBytes
   }
 
   import scala.collection.JavaConverters._
 
-  def maxProcessedOffset: Int = {
+  def maxProcessedOffset: Long = {
     val keys: util.Enumeration[_] = persistentStore.keys()
     keys.asScala.map(x => x.toString.toInt).max
   }
 
   /** Store a single id and corresponding serialized message */
-  override def store[T: ClassTag](id: Int, message: T): Boolean = {
-    val bytes: Array[Byte] = serializerInstance.serialize(message).array()
+  override def store[T](id: Long, message: T): Boolean = {
+    val bytes: Array[Byte] = serializer.serialize(message)
     try {
       persistentStore.put(id.toString, new MqttPersistableData(bytes))
       true
     } catch {
       case e: MqttPersistenceException => log.warn(s"Failed to store message Id: $id", e)
-      false
+        false
     }
   }
 
-  /** Retrieve messages corresponding to certain offset range */
-  override def retrieve[T: ClassTag](start: Int, end: Int): Seq[T] = {
-    (start until end).map(x => retrieve(x))
-  }
-
   /** Retrieve message corresponding to a given id. */
-  override def retrieve[T: ClassTag](id: Int): T = {
-    serializerInstance.deserialize(ByteBuffer.wrap(get(id)), classLoader)
+  override def retrieve[T](id: Long): T = {
+    serializer.deserialize(get(id))
   }
 
 }

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh b/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh
new file mode 100755
index 0000000..659dd8c
--- /dev/null
+++ b/sql-streaming-mqtt/src/test/bin/test-BAHIR-83.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+set -o pipefail
+
+for i in `seq 100` ; do
+  mvn scalatest:test -pl sql-streaming-mqtt -q -Dsuites='*.BasicMQTTSourceSuite' | \
+    grep -q "TEST FAILED" && echo "$i: failed"
+done

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
index 9c678cb..0a2a079 100644
--- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala
@@ -22,8 +22,7 @@ import java.io.File
 import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.SparkFunSuite
 
 import org.apache.bahir.utils.BahirUtils
 
@@ -31,9 +30,9 @@ import org.apache.bahir.utils.BahirUtils
 class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter {
 
   private val testData = Seq(1, 2, 3, 4, 5, 6)
-  private val javaSerializer: JavaSerializer = new JavaSerializer(new SparkConf())
+  private val javaSerializer: JavaSerializer = new JavaSerializer()
 
-  private val serializerInstance = javaSerializer.newInstance()
+  private val serializerInstance = javaSerializer
   private val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test2/")
   private val persistence: MqttDefaultFilePersistence =
     new MqttDefaultFilePersistence(tempDir.getAbsolutePath)
@@ -68,7 +67,7 @@ class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter {
   test("Max offset stored") {
     store.store(1, testData)
     store.store(10, testData)
-    val offset: Int = store.maxProcessedOffset
+    val offset = store.maxProcessedOffset
     assert(offset == 10)
   }
 

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
index 38971a0..2ce72da 100644
--- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala
@@ -18,31 +18,34 @@
 package org.apache.bahir.sql.streaming.mqtt
 
 import java.io.File
-import java.sql.Timestamp
+import java.util.Optional
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.concurrent.Future
 
 import org.eclipse.paho.client.mqttv3.MqttException
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
-import org.apache.spark.sql.{DataFrame, SQLContext}
-import org.apache.spark.sql.execution.streaming.LongOffset
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery}
 
 import org.apache.bahir.utils.BahirUtils
 
-
 class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter {
 
   protected var mqttTestUtils: MQTTTestUtils = _
   protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/")
 
   before {
+    tempDir.mkdirs()
+    if (!tempDir.exists()) {
+      throw new IllegalStateException("Unable to create temp directories.")
+    }
+    tempDir.deleteOnExit()
     mqttTestUtils = new MQTTTestUtils(tempDir)
     mqttTestUtils.setup()
-    tempDir.mkdirs()
   }
 
   after {
@@ -52,16 +55,44 @@ class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with B
 
   protected val tmpDir: String = tempDir.getAbsolutePath
 
-  protected def createStreamingDataframe(dir: String = tmpDir): (SQLContext, DataFrame) = {
+  protected def writeStreamResults(sqlContext: SQLContext, dataFrame: DataFrame): StreamingQuery = {
+    import sqlContext.implicits._
+    val query: StreamingQuery = dataFrame.selectExpr("CAST(payload AS STRING)").as[String]
+      .writeStream.format("parquet").start(s"$tmpDir/t.parquet")
+    while (!query.status.isTriggerActive) {
+      Thread.sleep(20)
+    }
+    query
+  }
+
+  protected def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = {
+    import sqlContext.implicits._
+    val asList =
+      sqlContext.read
+        .parquet(s"$tmpDir/t.parquet").as[String]
+        .collectAsList().asScala
+    asList
+  }
+
+  protected def createStreamingDataframe(dir: String = tmpDir,
+      filePersistence: Boolean = false): (SQLContext, DataFrame) = {
 
     val sqlContext: SQLContext = new SQLContext(sc)
 
     sqlContext.setConf("spark.sql.streaming.checkpointLocation", tmpDir)
 
-    val dataFrame: DataFrame =
+    val ds: DataStreamReader =
       sqlContext.readStream.format("org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider")
-        .option("topic", "test").option("localStorage", dir).option("clientId", "clientId")
-        .option("QoS", "2").load("tcp://" + mqttTestUtils.brokerUri)
+        .option("topic", "test").option("clientId", "clientId").option("connectionTimeout", "120")
+        .option("keepAlive", "1200").option("maxInflight", "120").option("autoReconnect", "false")
+        .option("cleanSession", "true").option("QoS", "2")
+
+    val dataFrame = if (!filePersistence) {
+      ds.option("persistence", "memory").load("tcp://" + mqttTestUtils.brokerUri)
+    } else {
+      ds.option("persistence", "file").option("localStorage", tmpDir)
+        .load("tcp://" + mqttTestUtils.brokerUri)
+    }
     (sqlContext, dataFrame)
   }
 
@@ -69,31 +100,16 @@ class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with B
 
 class BasicMQTTSourceSuite extends MQTTStreamSourceSuite {
 
-  private def writeStreamResults(sqlContext: SQLContext,
-      dataFrame: DataFrame, waitDuration: Long): Boolean = {
-    import sqlContext.implicits._
-    dataFrame.as[(String, Timestamp)].writeStream.format("parquet").start(s"$tmpDir/t.parquet")
-      .awaitTermination(waitDuration)
-  }
-
-  private def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = {
-    import sqlContext.implicits._
-    val asList =
-      sqlContext.read.schema(MQTTStreamConstants.SCHEMA_DEFAULT)
-        .parquet(s"$tmpDir/t.parquet").as[(String, Timestamp)].map(_._1)
-        .collectAsList().asScala
-    asList
-  }
-
   test("basic usage") {
 
     val sendMessage = "MQTT is a message queue."
 
-    mqttTestUtils.publishData("test", sendMessage)
-
     val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
 
-    writeStreamResults(sqlContext, dataFrame, 5000)
+    val query = writeStreamResults(sqlContext, dataFrame)
+    mqttTestUtils.publishData("test", sendMessage)
+    query.processAllAvailable()
+    query.awaitTermination(10000)
 
     val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
 
@@ -101,88 +117,58 @@ class BasicMQTTSourceSuite extends MQTTStreamSourceSuite {
     assert(resultBuffer.head == sendMessage)
   }
 
-  // TODO: reinstate this test after fixing BAHIR-83
-  ignore("Send and receive 100 messages.") {
+  test("Send and receive 50 messages.") {
 
     val sendMessage = "MQTT is a message queue."
 
-    import scala.concurrent.ExecutionContext.Implicits.global
-
     val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
 
-    Future {
-      Thread.sleep(2000)
-      mqttTestUtils.publishData("test", sendMessage, 100)
-    }
+    val q = writeStreamResults(sqlContext, dataFrame)
 
-    writeStreamResults(sqlContext, dataFrame, 10000)
+    mqttTestUtils.publishData("test", sendMessage, 50)
+    q.processAllAvailable()
+    q.awaitTermination(10000)
 
     val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext)
 
-    assert(resultBuffer.size == 100)
+    assert(resultBuffer.size == 50)
     assert(resultBuffer.head == sendMessage)
   }
 
   test("no server up") {
     val provider = new MQTTStreamSourceProvider
     val sqlContext: SQLContext = new SQLContext(sc)
-    val parameters = Map("brokerUrl" -> "tcp://localhost:1883", "topic" -> "test",
-      "localStorage" -> tmpDir)
+    val parameters = new DataSourceOptions(Map("brokerUrl" ->
+      "tcp://localhost:1881", "topic" -> "test", "localStorage" -> tmpDir).asJava)
     intercept[MqttException] {
-      provider.createSource(sqlContext, "", None, "", parameters)
+      provider.createMicroBatchReader(Optional.empty(), tempDir.toString, parameters)
     }
   }
 
   test("params not provided.") {
     val provider = new MQTTStreamSourceProvider
-    val sqlContext: SQLContext = new SQLContext(sc)
-    val parameters = Map("brokerUrl" -> mqttTestUtils.brokerUri,
-      "localStorage" -> tmpDir)
+    val parameters = new DataSourceOptions(Map("brokerUrl" -> mqttTestUtils.brokerUri,
+      "localStorage" -> tmpDir).asJava)
     intercept[IllegalArgumentException] {
-      provider.createSource(sqlContext, "", None, "", parameters)
+      provider.createMicroBatchReader(Optional.empty(), tempDir.toString, parameters)
     }
     intercept[IllegalArgumentException] {
-      provider.createSource(sqlContext, "", None, "", Map())
+      provider.createMicroBatchReader(Optional.empty(), tempDir.toString, DataSourceOptions.empty())
     }
   }
 
-  // TODO: reinstate this test after fixing BAHIR-83
-  ignore("Recovering offset from the last processed offset.") {
-    val sendMessage = "MQTT is a message queue."
-
-    import scala.concurrent.ExecutionContext.Implicits.global
-
-    val (sqlContext: SQLContext, dataFrame: DataFrame) =
-      createStreamingDataframe()
-
-    Future {
-      Thread.sleep(2000)
-      mqttTestUtils.publishData("test", sendMessage, 100)
-    }
-
-    writeStreamResults(sqlContext, dataFrame, 10000)
-    // On restarting the source with same params, it should begin from the offset - the
-    // previously running stream left off.
-    val provider = new MQTTStreamSourceProvider
-    val parameters = Map("brokerUrl" -> ("tcp://" + mqttTestUtils.brokerUri), "topic" -> "test",
-      "localStorage" -> tmpDir, "clientId" -> "clientId", "QoS" -> "2")
-    val offset: Long = provider.createSource(sqlContext, "", None, "", parameters)
-      .getOffset.get.asInstanceOf[LongOffset].offset
-    assert(offset == 100L)
-  }
-
 }
 
 class StressTestMQTTSource extends MQTTStreamSourceSuite {
 
   // Run with -Xmx1024m
-  ignore("Send and receive messages of size 250MB.") {
+  test("Send and receive messages of size 100MB.") {
 
     val freeMemory: Long = Runtime.getRuntime.freeMemory()
 
     log.info(s"Available memory before test run is ${freeMemory / (1024 * 1024)}MB.")
 
-    val noOfMsgs = (250 * 1024 * 1024) / (500 * 1024) // 512
+    val noOfMsgs: Int = (100 * 1024 * 1024) / (500 * 1024) // 204
 
     val messageBuilder = new StringBuilder()
     for (i <- 0 until (500 * 1024)) yield messageBuilder.append(((i % 26) + 65).toChar)
@@ -190,22 +176,14 @@ class StressTestMQTTSource extends MQTTStreamSourceSuite {
 
     val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataframe()
 
-    import scala.concurrent.ExecutionContext.Implicits.global
-    Future {
-      Thread.sleep(2000)
-      mqttTestUtils.publishData("test", sendMessage, noOfMsgs.toInt)
-    }
-
-    import sqlContext.implicits._
-
-    dataFrame.as[(String, Timestamp)].writeStream
-      .format("parquet")
-      .start(s"$tmpDir/t.parquet")
-      .awaitTermination(25000)
+    val query = writeStreamResults(sqlContext, dataFrame)
+    mqttTestUtils.publishData("test", sendMessage, noOfMsgs )
+    query.processAllAvailable()
+    query.awaitTermination(25000)
 
     val messageCount =
-      sqlContext.read.schema(MQTTStreamConstants.SCHEMA_DEFAULT)
-        .parquet(s"$tmpDir/t.parquet").as[(String, Timestamp)].map(_._1)
+      sqlContext.read
+        .parquet(s"$tmpDir/t.parquet")
         .count()
     assert(messageCount == noOfMsgs)
   }

http://git-wip-us.apache.org/repos/asf/bahir/blob/b3902bac/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
index 9c7399f..817ec9a 100644
--- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
+++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTTestUtils.scala
@@ -22,15 +22,14 @@ import java.net.{ServerSocket, URI}
 
 import org.apache.activemq.broker.{BrokerService, TransportConnector}
 import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence}
 
 import org.apache.bahir.utils.Logging
 
 
 class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging {
 
-  private val persistenceDir = tempDir.getAbsolutePath
-  private val brokerHost = "localhost"
+  private val brokerHost = "127.0.0.1"
   private val brokerPort: Int = if (port == 0) findFreePort() else port
 
   private var broker: BrokerService = _
@@ -60,18 +59,21 @@ class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging {
   def teardown(): Unit = {
     if (broker != null) {
       broker.stop()
-      broker = null
     }
     if (connector != null) {
       connector.stop()
       connector = null
     }
+    while (!broker.isStopped) {
+      Thread.sleep(50)
+    }
+    broker = null
   }
 
   def publishData(topic: String, data: String, N: Int = 1): Unit = {
     var client: MqttClient = null
     try {
-      val persistence = new MqttDefaultFilePersistence(persistenceDir)
+      val persistence = new MemoryPersistence()
       client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
       client.connect()
       if (client.isConnected) {
@@ -81,7 +83,7 @@ class MQTTTestUtils(tempDir: File, port: Int = 0) extends Logging {
             Thread.sleep(20)
             val message = new MqttMessage(data.getBytes())
             message.setQos(2)
-            message.setRetained(true)
+            // message.setId(i) setting id has no effect.
             msgTopic.publish(message)
           } catch {
             case e: MqttException =>
