Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21700#discussion_r202927242
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala ---
    @@ -64,21 +64,122 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
         require(!StateStore.isMaintenanceRunning)
       }
     
    +  def updateVersionTo(provider: StateStoreProvider, currentVersion: => Int,
    +                      targetVersion: Int): Int = {
    +    var newCurrentVersion = currentVersion
    +    for (i <- newCurrentVersion + 1 to targetVersion) {
    +      val store = provider.getStore(newCurrentVersion)
    +      put(store, "a", i)
    +      store.commit()
    +      newCurrentVersion += 1
    +    }
    +    require(newCurrentVersion === targetVersion)
    +    newCurrentVersion
    +  }
    +
    +  test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") {
    +    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
    +      numOfVersToRetainInMemory = 2)
    +
    +    def restoreOriginValues(map: provider.MapType): Map[String, Int] = {
    +      map.asScala.map(entry => rowToString(entry._1) -> rowToInt(entry._2)).toMap
    +    }
    +
    +    var currentVersion = 0
    +    currentVersion = updateVersionTo(provider, currentVersion, 1)
    +    assert(getData(provider) === Set("a" -> 1))
    +    var loadedMaps = provider.getClonedLoadedMaps()
    +    assert(loadedMaps.size() === 1)
    +    assert(loadedMaps.firstKey() === 1L)
    +    assert(restoreOriginValues(loadedMaps.get(1L)) === Map("a" -> 1))
    +
    +    currentVersion = updateVersionTo(provider, currentVersion, 2)
    +    assert(getData(provider) === Set("a" -> 2))
    +    loadedMaps = provider.getClonedLoadedMaps()
    +    assert(loadedMaps.size() === 2)
    +    assert(loadedMaps.firstKey() === 2L)
    +    assert(loadedMaps.lastKey() === 1L)
    +    assert(restoreOriginValues(loadedMaps.get(2L)) === Map("a" -> 2))
    +    assert(restoreOriginValues(loadedMaps.get(1L)) === Map("a" -> 1))
    +
    +    // this trigger exceeding cache and 1 will be evicted
    +    currentVersion = updateVersionTo(provider, currentVersion, 3)
    +    assert(getData(provider) === Set("a" -> 3))
    +    loadedMaps = provider.getClonedLoadedMaps()
    +    assert(loadedMaps.size() === 2)
    +    assert(loadedMaps.firstKey() === 3L)
    +    assert(loadedMaps.lastKey() === 2L)
    +    assert(restoreOriginValues(loadedMaps.get(3L)) === Map("a" -> 3))
    +    assert(restoreOriginValues(loadedMaps.get(2L)) === Map("a" -> 2))
    --- End diff --
    
    This can also be factored out into a convenience method, e.g. `def checkVersion(version: Int, expectedData: Map[String, Int])`, to reduce the verbosity of these repeated assertions.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to