Repository: bahir Updated Branches: refs/heads/master a45bd8421 -> 5cfd7ac31
[BAHIR-175] Fix MQTT recovery after checkpoint Closes #79 Project: http://git-wip-us.apache.org/repos/asf/bahir/repo Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/5cfd7ac3 Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/5cfd7ac3 Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/5cfd7ac3 Branch: refs/heads/master Commit: 5cfd7ac3154621b1780e2eb4719731030fc7d80a Parents: a45bd84 Author: Lukasz Antoniak <[email protected]> Authored: Wed Dec 19 13:23:58 2018 -0800 Committer: Luciano Resende <[email protected]> Committed: Wed Jan 9 16:14:00 2019 -0800 ---------------------------------------------------------------------- .../sql/streaming/mqtt/MQTTStreamSource.scala | 9 ++++-- .../bahir/sql/streaming/mqtt/MessageStore.scala | 9 +++++- .../streaming/mqtt/MQTTStreamSourceSuite.scala | 34 +++++++++++++++++++- 3 files changed, 47 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/bahir/blob/5cfd7ac3/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala index a40ff51..7146ecc 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala @@ -101,9 +101,9 @@ class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistenc /* Older than last N messages, will not be checked for redelivery. */ val backLog = options.getInt("autopruning.backlog", 500) - private val store = new LocalMessageStore(persistence) + private[mqtt] val store = new LocalMessageStore(persistence) - private val messages = new TrieMap[Long, MQTTMessage] + private[mqtt] val messages = new TrieMap[Long, MQTTMessage] @GuardedBy("this") private var currentOffset: LongOffset = LongOffset(-1L) @@ -125,6 +125,7 @@ class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistenc val mqttMessage = new MQTTMessage(message, topic_) val offset = currentOffset.offset + 1L messages.put(offset, mqttMessage) + store.store(offset, mqttMessage) currentOffset = LongOffset(offset) log.trace(s"Message arrived, $topic_ $mqttMessage") } @@ -172,7 +173,8 @@ class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistenc val rawList: IndexedSeq[MQTTMessage] = synchronized { val sliceStart = LongOffset.convert(startOffset).get.offset + 1 val sliceEnd = LongOffset.convert(endOffset).get.offset + 1 - for (i <- sliceStart until sliceEnd) yield messages(i) + for (i <- sliceStart until sliceEnd) yield + messages.getOrElse(i, store.retrieve[MQTTMessage](i)) } val spark = SparkSession.getActiveSession.get val numPartitions = spark.sparkContext.defaultParallelism @@ -218,6 +220,7 @@ class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistenc (lastOffsetCommitted.offset until newOffset.offset).foreach { x => messages.remove(x + 1) + store.remove(x + 1) } lastOffsetCommitted = newOffset } http://git-wip-us.apache.org/repos/asf/bahir/blob/5cfd7ac3/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala index d7d2657..30ec7a6 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala @@ -39,6 +39,9 @@ trait MessageStore { /** Highest offset we have stored */ def maxProcessedOffset: Long + /** Remove message corresponding to a given id. */ + def remove[T](id: Long): Unit + } private[mqtt] class MqttPersistableData(bytes: Array[Byte]) extends MqttPersistable { @@ -118,7 +121,7 @@ private[mqtt] class LocalMessageStore(val persistentStore: MqttClientPersistence import scala.collection.JavaConverters._ - def maxProcessedOffset: Long = { + override def maxProcessedOffset: Long = { val keys: util.Enumeration[_] = persistentStore.keys() keys.asScala.map(x => x.toString.toInt).max } @@ -140,4 +143,8 @@ private[mqtt] class LocalMessageStore(val persistentStore: MqttClientPersistence serializer.deserialize(get(id)) } + override def remove[T](id: Long): Unit = { + persistentStore.remove(id.toString) + } + } http://git-wip-us.apache.org/repos/asf/bahir/blob/5cfd7ac3/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala ---------------------------------------------------------------------- diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala index a7eb770..c4e340c 100644 --- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala @@ -23,8 +23,13 @@ import java.util.Optional import scala.collection.JavaConverters._ import scala.collection.mutable +import org.eclipse.paho.client.mqttv3.MqttConnectOptions import org.eclipse.paho.client.mqttv3.MqttException +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually +import org.scalatest.time +import org.scalatest.time.Span import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.sql._ @@ -33,7 +38,8 @@ import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery} import org.apache.bahir.utils.FileHelper -class MQTTStreamSourceSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { +class MQTTStreamSourceSuite extends SparkFunSuite + with Eventually with SharedSparkContext with BeforeAndAfter { protected var mqttTestUtils: MQTTTestUtils = _ protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/") @@ -136,6 +142,32 @@ class BasicMQTTSourceSuite extends MQTTStreamSourceSuite { assert(resultBuffer.head == sendMessage) } + test("messages persisted in store") { + val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext + val source = new MQTTStreamSource( + DataSourceOptions.empty(), "tcp://" + mqttTestUtils.brokerUri, new MemoryPersistence(), + "test", "clientId", new MqttConnectOptions(), 2 + ) + val payload = "MQTT is a message queue." + mqttTestUtils.publishData("test", payload) + eventually(timeout(Span(5, time.Seconds)), interval(Span(500, time.Millis))) { + val message = source.store.retrieve(0).asInstanceOf[Object] + assert(message != null) + } + // Clear in-memory cache to simulate recovery. + source.messages.clear() + source.setOffsetRange(Optional.empty(), Optional.empty()) + var message: Row = null + for (f <- source.createDataReaderFactories().asScala) { + val dataReader = f.createDataReader() + if (dataReader.next()) { + message = dataReader.get() + } + } + source.commit(source.getCurrentOffset) + assert(payload == new String(message.getAs[Array[Byte]](2), "UTF-8")) + } + test("no server up") { val provider = new MQTTStreamSourceProvider val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext
