Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/21700#discussion_r202926467
--- Diff:
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
---
@@ -64,21 +64,122 @@ class StateStoreSuite extends
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
require(!StateStore.isMaintenanceRunning)
}
+ /**
+  * Steps `provider` forward one version at a time from `currentVersion` up to
+  * `targetVersion`. For each step it loads the store at the current version `v`,
+  * puts ("a" -> v + 1) into it and commits.
+  *
+  * @param provider       state store provider under test
+  * @param currentVersion version to start from (evaluated once)
+  * @param targetVersion  version to stop at
+  * @return the version reached, which is required to equal `targetVersion`;
+  *         the `require` fails if `targetVersion` is below `currentVersion`
+  */
+ def updateVersionTo(provider: StateStoreProvider, currentVersion: => Int,
+     targetVersion: Int): Int = {
+   val startVersion = currentVersion
+   val reachedVersion = (startVersion + 1 to targetVersion).foldLeft(startVersion) {
+     (loadedVersion, valueToPut) =>
+       val store = provider.getStore(loadedVersion)
+       put(store, "a", valueToPut)
+       store.commit()
+       loadedVersion + 1
+   }
+   require(reachedVersion === targetVersion)
+   reachedVersion
+ }
+
+ test("retaining only two latest versions when
MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") {
+ val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
+ numOfVersToRetainInMemory = 2)
+
+ def restoreOriginValues(map: provider.MapType): Map[String, Int] = {
+ map.asScala.map(entry => rowToString(entry._1) ->
rowToInt(entry._2)).toMap
+ }
+
+ var currentVersion = 0
+ currentVersion = updateVersionTo(provider, currentVersion, 1)
+ assert(getData(provider) === Set("a" -> 1))
+ var loadedMaps = provider.getClonedLoadedMaps()
+ assert(loadedMaps.size() === 1)
+ assert(loadedMaps.firstKey() === 1L)
+ assert(restoreOriginValues(loadedMaps.get(1L)) === Map("a" -> 1))
+
+ currentVersion = updateVersionTo(provider, currentVersion, 2)
+ assert(getData(provider) === Set("a" -> 2))
+ loadedMaps = provider.getClonedLoadedMaps()
+ assert(loadedMaps.size() === 2)
+ assert(loadedMaps.firstKey() === 2L)
+ assert(loadedMaps.lastKey() === 1L)
--- End diff --
You could factor these repeated size/firstKey/lastKey assertions into a helper function,
e.g. `def checkLoadedVersions(num: Int, earliest: Int, latest: Int)`, and call it after each update.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]