Github user HeartSaVioR commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21700#discussion_r203577561
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala ---
    @@ -64,21 +66,143 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
         require(!StateStore.isMaintenanceRunning)
       }
     
    +  def updateVersionTo(
    +      provider: StateStoreProvider,
    +      currentVersion: Int,
    +      targetVersion: Int): Int = {
    +    var newCurrentVersion = currentVersion
    +    for (i <- newCurrentVersion until targetVersion) {
    +      newCurrentVersion = incrementVersion(provider, i)
    +    }
    +    require(newCurrentVersion === targetVersion)
    +    newCurrentVersion
    +  }
    +
    +  def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = {
    +    val store = provider.getStore(currentVersion)
    +    put(store, "a", currentVersion + 1)
    +    store.commit()
    +    currentVersion + 1
    +  }
    +
    +  def checkLoadedVersions(
    +      loadedMaps: util.SortedMap[Long, ProviderMapType],
    +      count: Int,
    +      earliestKey: Long,
    +      latestKey: Long): Unit = {
    +    assert(loadedMaps.size() === count)
    +    assert(loadedMaps.firstKey() === earliestKey)
    +    assert(loadedMaps.lastKey() === latestKey)
    +  }
    +
    +  def checkVersion(
    +      loadedMaps: util.SortedMap[Long, ProviderMapType],
    +      version: Long,
    +      expectedData: Map[String, Int]): Unit = {
    +
    +    val originValueMap = loadedMaps.get(version).asScala.map { entry =>
    +      rowToString(entry._1) -> rowToInt(entry._2)
    +    }.toMap
    +
    +    assert(originValueMap === expectedData)
    +  }
    +
    +  test("retaining only two latest versions when MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") {
    +    val provider = newStoreProvider(opId = Random.nextInt, partition = 0,
    +      numOfVersToRetainInMemory = 2)
    +
    +    var currentVersion = 0
    +
    +    // commit version 1: the cache will have one element
    +    currentVersion = incrementVersion(provider, currentVersion)
    +    assert(getData(provider) === Set("a" -> 1))
    +    var loadedMaps = provider.getClonedLoadedMaps()
    +    checkLoadedVersions(loadedMaps, 1, 1L, 1L)
    --- End diff --
    
    Yeah, I'd add 'L' everywhere the literal is meant to be a Long, so that we don't rely on implicit Int-to-Long widening and the type is stated explicitly, but I have no strong opinion about this. I can follow existing Spark preferences.
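    
    For illustration only (not part of the PR), a minimal Scala sketch of the two styles, using a hypothetical stand-in helper whose key parameters are typed Long like checkLoadedVersions in the diff:
    
        object LongLiteralSketch {
          // Hypothetical stand-in: key parameters are typed Long, as in checkLoadedVersions.
          def checkKeys(count: Int, earliestKey: Long, latestKey: Long): Unit =
            println(s"count=$count earliest=$earliestKey latest=$latestKey")
    
          def main(args: Array[String]): Unit = {
            // Int literals: compiles because Scala widens Int to Long implicitly.
            checkKeys(1, 1, 1)
            // Explicit Long literals: the 'L' suffix states the intended type at the call site.
            checkKeys(1, 1L, 1L)
          }
        }
    
    Either call compiles; the 'L' form only makes the intended type visible where the literal is written.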


---
