neilramaswamy commented on code in PR #48853:
URL: https://github.com/apache/spark/pull/48853#discussion_r1859114729


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##########
@@ -19,274 +19,575 @@ package org.apache.spark.sql.execution.streaming
 import java.time.Duration
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import 
org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, 
StateStore}
+import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.streaming.TTLConfig
 import org.apache.spark.sql.types._
 
-object StateTTLSchema {
-  val TTL_VALUE_ROW_SCHEMA: StructType =
-    StructType(Array(StructField("__dummy__", NullType)))
-}
-
 /**
- * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]].
+ * Any state variable that wants to support TTL must implement this trait,
+ * which they can do by extending [[OneToOneTTLState]] or 
[[OneToManyTTLState]].
  *
- * @param groupingKey grouping key for which ttl is set
- * @param expirationMs expiration time for the grouping key
- */
-case class SingleKeyTTLRow(
-    groupingKey: UnsafeRow,
-    expirationMs: Long)
-
-/**
- * Encapsulates the ttl row information stored in [[CompositeKeyTTLStateImpl]].
+ * The only required methods here are ones relating to evicting expired and all
+ * state, via clearExpiredStateForAllKeys and clearAllStateForElementKey,
+ * respectively. How classes do this is implementation detail, but the general
+ * pattern is to use secondary indexes to make sure cleanup scans
+ * theta(records to evict), not theta(all records).
  *
- * @param groupingKey grouping key for which ttl is set
- * @param userKey user key for which ttl is set
- * @param expirationMs expiration time for the grouping key
- */
-case class CompositeKeyTTLRow(
-   groupingKey: UnsafeRow,
-   userKey: UnsafeRow,
-   expirationMs: Long)
-
-/**
- * Represents the underlying state for secondary TTL Index for a user defined
- * state variable.
+ * There are two broad patterns of implementing stateful variables, and thus
+ * there are two broad patterns for implementing TTL. The first is when there
+ * is a one-to-one mapping between an element key [1] and a value; the primary
+ * and secondary index management for this case is implemented by
+ * [[OneToOneTTLState]]. When a single element key can have multiple values,
+ * all of which can expire at their own, unique times, then
+ * [[OneToManyTTLState]] should be used.
+ *
+ * In either case, implementations need to use some sort of secondary index
+ * that orders element keys by expiration time. This base functionality
+ * is provided by methods in this trait that read/write/delete to the
+ * so-called "TTL index". It is a secondary index with the layout of
+ * (expirationMs, elementKey) -> EMPTY_ROW. The expirationMs is big-endian
+ * encoded to allow for efficient range scans to find all expired keys.
+ *
+ * TTLState (or any abstract sub-classes) should never deal with encoding or
+ * decoding UnsafeRows to and from their user-facing types. The stateful variables
+ * themselves should be doing this; all other TTLState sub-classes should be 
concerned
+ * only with writing, reading, and deleting UnsafeRows and their associated
+ * expirations from the primary and secondary indexes. [2]
+ *
+ * [1]. You might ask, why call it "element key" instead of "grouping key"?
+ *      This is because a single grouping key might have multiple elements, as 
in
+ *      the case of a map, which has composite keys of the form (groupingKey, 
mapKey).
+ *      In the case of ValueState, though, the element key is the grouping key.
+ *      To generalize to both cases, this class should always use the term elementKey.
  *
- * This state allows Spark to query ttl values based on expiration time
- * allowing efficient ttl cleanup.
+ * [2]. You might also ask, why design it this way? We want the TTLState 
abstract
+ *      sub-classes to write to both the primary and secondary indexes, since 
they
+ *      both need to stay in sync; co-locating the logic is cleanest.
  */
 trait TTLState {
+  // Name of the state variable, e.g. the string the user passes to 
get{Value/List/Map}State
+  // in the init() method of a StatefulProcessor.
+  def stateName: String
+
+  // The StateStore instance used to store the state. There is only one 
instance shared
+  // among the primary and secondary indexes, since it uses virtual column 
families
+  // to keep the indexes separate.
+  def store: StateStore
+
+  // The schema of the primary key for the state variable. For value and list 
state, this
+  // is the grouping key. For map state, this is the composite key of the 
grouping key and
+  // a map key.
+  def elementKeySchema: StructType
+
+  // The timestamp at which the batch is being processed. All state variables 
that have
+  // an expiration at or before this timestamp must be cleaned up.
+  def batchTimestampMs: Long
+
+  // The configuration for this run of the streaming query. It may change 
between runs
+  // (e.g. user sets ttlConfig1, stops their query, updates to ttlConfig2, and 
then
+  // resumes their query).
+  def ttlConfig: TTLConfig
+
+  // A map from metric name to the underlying SQLMetric. This should not be 
updated
+  // by the underlying state variable, as the TTL state implementation should 
be
+  // handling all reads/writes/deletes to the indexes.
+  def metrics: Map[String, SQLMetric] = Map.empty
+
+  private final val TTL_INDEX = "$ttl_" + stateName
+  private final val TTL_INDEX_KEY_SCHEMA = getTTLRowKeySchema(elementKeySchema)
+  private final val TTL_EMPTY_VALUE_ROW_SCHEMA: StructType =
+    StructType(Array(StructField("__empty__", NullType)))
+
+  private final val TTL_ENCODER = new TTLEncoder(elementKeySchema)
+
+  // Empty row used for values
+  private final val TTL_EMPTY_VALUE_ROW =
+    
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+  protected final def ttlExpirationMs = StateTTL
+    .calculateExpirationTimeForDuration(ttlConfig.ttlDuration, 
batchTimestampMs)
+
+  store.createColFamilyIfAbsent(
+    TTL_INDEX,
+    TTL_INDEX_KEY_SCHEMA,
+    TTL_EMPTY_VALUE_ROW_SCHEMA,
+    RangeKeyScanStateEncoderSpec(TTL_INDEX_KEY_SCHEMA, Seq(0)),
+    isInternal = true
+  )
+
+  protected def insertIntoTTLIndex(expirationMs: Long, elementKey: UnsafeRow): 
Unit = {
+    val secondaryIndexKey = TTL_ENCODER.encodeTTLRow(expirationMs, elementKey)
+    store.put(secondaryIndexKey, TTL_EMPTY_VALUE_ROW, TTL_INDEX)
+  }
+
+  // The deleteFromTTLIndex overload that takes an expiration time and 
elementKey as an
+  // argument is used when we need to _construct_ the key to delete from the 
TTL index.
+  //
+  // If we know the timestamp to delete and the elementKey, but don't have a 
pre-constructed
+  // UnsafeRow, then you should use this method to delete from the TTL index.
+  protected def deleteFromTTLIndex(expirationMs: Long, elementKey: UnsafeRow): 
Unit = {
+    val secondaryIndexKey = TTL_ENCODER.encodeTTLRow(expirationMs, elementKey)
+    store.remove(secondaryIndexKey, TTL_INDEX)
+  }
+
+  // The deleteFromTTLIndex overload that takes an UnsafeRow as an argument is 
used when
+  // we're deleting elements from the TTL index that we are iterating over.
+  //
+  // If we were to use the other deleteFromTTLIndex method, we would have to 
re-encode the
+  // components into an UnsafeRow. It is more efficient to just pass the 
UnsafeRow that we
+  // read from the iterator.
+  protected def deleteFromTTLIndex(ttlKey: UnsafeRow): Unit = {
+    store.remove(ttlKey, TTL_INDEX)
+  }
+
+  private[sql] def toTTLRow(ttlKey: UnsafeRow): TTLRow = {
+    val expirationMs = ttlKey.getLong(0)
+    val elementKey = ttlKey.getStruct(1, TTL_INDEX_KEY_SCHEMA.length)
+    TTLRow(elementKey, expirationMs)
+  }
+
+  private[sql] def getTTLRows(): Iterator[TTLRow] = {
+    store.iterator(TTL_INDEX).map(kv => toTTLRow(kv.key))
+  }
+
+  // Returns an Iterator over all the keys in the TTL index that have expired. 
This method
+  // does not delete the keys from the TTL index; it is the responsibility of 
the caller
+  // to do so.
+  //
+  // The schema of the UnsafeRow returned by this iterator is (expirationMs, 
elementKey).
+  protected def ttlEvictionIterator(): Iterator[UnsafeRow] = {
+    val ttlIterator = store.iterator(TTL_INDEX)
+
+    // Recall that the format is (expirationMs, elementKey) -> 
TTL_EMPTY_VALUE_ROW, so
+    // kv.value doesn't ever need to be used.
+    ttlIterator.takeWhile(kv => {
+      val expirationMs = kv.key.getLong(0)
+      StateTTL.isExpired(expirationMs, batchTimestampMs)
+    }).map(_.key)
+  }
+
+
+  // Encapsulates a row stored in a TTL index.
+  protected case class TTLRow(elementKey: UnsafeRow, expirationMs: Long)
 
   /**
-   * Perform the user state clean up based on ttl values stored in
-   * this state. NOTE that its not safe to call this operation concurrently
-   * when the user can also modify the underlying State. Cleanup should be 
initiated
-   * after arbitrary state operations are completed by the user.
+   * Evicts the state associated with this stateful variable that has expired
+   * due to TTL. The eviction applies to all grouping keys, and to all indexes,
+   * primary or secondary.
+   *
+   * This method can be called at any time in the micro-batch execution,
+   * as long as it is allowed to complete before subsequent state operations 
are
+   * issued. Operations to the state variable should not be issued 
concurrently while
+   * this is running.
+   *
+   * (Why? Some cleanup operations leave the state store in an inconsistent 
state while
+   * they are doing cleanup. For example, they may choose to get an iterator 
over all
+   * of the existing values, delete all values, and then re-insert only the 
non-expired
+   * values from the iterator. If a get is issued after the delete happens but 
before the
+   * re-insertion completes, the get could return null even when the value 
does actually
+   * exist.)
    *
    * @return number of values cleaned up.
    */
-  def clearExpiredState(): Long
+  def clearExpiredStateForAllKeys(): Long
+
+  /**
+   * When a user calls clear() on a stateful variable, this method is invoked 
to
+   * clear all of the state for the current (implicit) grouping key. It is 
responsible
+   * for deleting from the primary index as well as any secondary index(es).
+   *
+   * If a given state variable has to clean up multiple elementKeys (in 
MapState, for
+   * example, every key in the map is its own elementKey), then this method 
should
+   * be invoked for each of those keys.
+   */
+  def clearAllStateForElementKey(elementKey: UnsafeRow): Unit
 }
 
 /**
- * Manages the ttl information for user state keyed with a single key 
(grouping key).
+ * [[OneToManyTTLState]] is an implementation of [[TTLState]] for stateful 
variables
+ * that associate a single key with multiple values; every value has its own 
expiration
+ * timestamp.
+ *
+ * We need an efficient way to find all the values that have expired, but we 
cannot
+ * issue point-wise deletes to the elements, since they are merged together 
using the
+ * RocksDB StringAppendOperator for merging. As such, we cannot keep a 
secondary index
+ * on the key (expirationMs, groupingKey, indexInList), since we have no way 
to delete a
+ * specific indexInList from the RocksDB value. (In the future, we could write 
a custom
+ * merge operator that can handle tombstones for deleted indexes, but RocksDB 
doesn't
+ * support custom merge operators written in Java/Scala.)
+ *
+ * Instead, we manage expiration per grouping key. Our secondary index will look
+ * like (expirationMs, groupingKey) -> EMPTY_ROW. This way, we can quickly 
find all the
+ * grouping keys that contain at least one element that has expired.
+ *
+ * There is some trickiness here, though. Suppose we have an element key `k` 
that
+ * has a list with one value `v1` that expires at time `t1`. Our primary index 
looks like
+ * k -> [v1]; our secondary index looks like [(t1, k) -> EMPTY_ROW]. Now, we 
add another
+ * value to the list, `v2`, that expires at time `t2`. The primary index 
updates to be
+ * k -> [v1, v2]. However, how do we update our secondary index? We already 
have an entry
+ * in our secondary index for `k`, but it's prefixed with `t1`, which we don't 
know at the
+ * time of inserting `v2`.
+ *
+ * So, do we:
+ *    1. Blindly add (t2, k) -> EMPTY_ROW to the secondary index?
+ *    2. Delete (t1, k) from the secondary index, and then add (t2, k) -> 
EMPTY_ROW?
+ *
+ * We prefer option 2 because it avoids having many entries in the secondary
+ * index for the same grouping key. But when we have (t2, k), how do we know 
to delete
+ * (t1, k)? How do we know that t1 is prefixing k in the secondary index?
+ *
+ * To solve this, we introduce another index that maps element key to the
+ * minimum expiry timestamp for that element key. It is called the min-expiry 
index.
+ * With it, we can know how to update our secondary index, and we can still 
range-scan
+ * the TTL index to find all the keys that have expired.
+ *
+ * In our previous example, we'd use this min-expiry index to tell us whether 
`t2` was
+ * smaller than the minimum expiration on-file for `k`. If it were smaller, 
we'd update
+ * the TTL index, and the min-expiry index. If not, then we'd just update the 
primary
+ * index. When the batchTimestampMs exceeded `t1`, we'd know to clean up `k` 
since it would
+ * contain at least one expired value. When iterating through the many values 
for this `k`,
+ * we'd then be able to find the next minimum value to insert back into the 
secondary and
+ * min-expiry index.
+ *
+ * All of this logic is implemented by the updatePrimaryAndSecondaryIndices 
function.
  */
-abstract class SingleKeyTTLStateImpl(
-    stateName: String,
-    store: StateStore,
-    keyExprEnc: ExpressionEncoder[Any],
-    ttlExpirationMs: Long)
-  extends TTLState {
-
-  import org.apache.spark.sql.execution.streaming.StateTTLSchema._
-
-  private val ttlColumnFamilyName = "$ttl_" + stateName
-  private val keySchema = getSingleKeyTTLRowSchema(keyExprEnc.schema)
-  private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc)
+abstract class OneToManyTTLState(
+    stateNameArg: String,
+    storeArg: StateStore,
+    elementKeySchemaArg: StructType,
+    ttlConfigArg: TTLConfig,
+    batchTimestampMsArg: Long,
+    metricsArg: Map[String, SQLMetric]) extends TTLState {
+  override def stateName: String = stateNameArg
+  override def store: StateStore = storeArg
+  override def elementKeySchema: StructType = elementKeySchemaArg
+  override def ttlConfig: TTLConfig = ttlConfigArg
+  override def batchTimestampMs: Long = batchTimestampMsArg
+  override def metrics: Map[String, SQLMetric] = metricsArg
+
+  // Schema of the min index: elementKey -> minExpirationMs
+  private val MIN_INDEX = "$min_" + stateName
+  private val MIN_INDEX_SCHEMA = elementKeySchema
+  private val MIN_INDEX_VALUE_SCHEMA = getExpirationMsRowSchema()
+
+  // Projects a Long into an UnsafeRow
+  private val minIndexValueProjector = 
UnsafeProjection.create(MIN_INDEX_VALUE_SCHEMA)
+
+  // Schema of the entry count index: elementKey -> count
+  private val COUNT_INDEX = "$count_" + stateName
+  private val COUNT_INDEX_VALUE_SCHEMA: StructType =
+    StructType(Seq(StructField("count", LongType, nullable = false)))
+  private val countIndexValueProjector = 
UnsafeProjection.create(COUNT_INDEX_VALUE_SCHEMA)
+
+  // Reused internal row that we use to create an UnsafeRow with the schema of
+  // COUNT_INDEX_VALUE_SCHEMA and the desired value. It is not thread safe 
(although, anyway,
+  // this class is not thread safe).
+  private val reusedCountIndexValueRow = new GenericInternalRow(1)
+
+  store.createColFamilyIfAbsent(
+    MIN_INDEX,
+    MIN_INDEX_SCHEMA,
+    MIN_INDEX_VALUE_SCHEMA,
+    NoPrefixKeyStateEncoderSpec(MIN_INDEX_SCHEMA),
+    isInternal = true
+  )
 
-  // empty row used for values
-  private val EMPTY_ROW =
-    
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+  store.createColFamilyIfAbsent(
+    COUNT_INDEX,
+    elementKeySchema,
+    COUNT_INDEX_VALUE_SCHEMA,
+    NoPrefixKeyStateEncoderSpec(elementKeySchema),
+    isInternal = true
+  )
 
-  store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema, 
TTL_VALUE_ROW_SCHEMA,
-    RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true)
 
   /**
-   * This function will be called when clear() on State Variables
-   * with ttl enabled is called. This function should clear any
-   * associated ttlState, since we are clearing the user state.
+   * Function to get the number of entries in the list state for a given 
grouping key
+   * @param elementKey - encoded element key
+   * @return - number of entries in the list state
    */
-  def clearTTLState(): Unit = {
-    val iterator = store.iterator(ttlColumnFamilyName)
-    iterator.foreach { kv =>
-      store.remove(kv.key, ttlColumnFamilyName)
+  def getEntryCount(elementKey: UnsafeRow): Long = {
+    val countRow = store.get(elementKey, COUNT_INDEX)
+    if (countRow != null) {
+      countRow.getLong(0)
+    } else {
+      0L
     }
   }
 
-  def upsertTTLForStateKey(
-      expirationMs: Long,
-      groupingKey: UnsafeRow): Unit = {
-    val encodedTtlKey = keyTTLRowEncoder.encodeTTLRow(
-      expirationMs, groupingKey)
-    store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName)
+  /**
+   * Function to update the number of entries in the list state for a given 
element key
+   * @param elementKey - encoded element key
+   * @param updatedCount - updated count of entries in the list state
+   */
+  def updateEntryCount(elementKey: UnsafeRow, updatedCount: Long): Unit = {
+    reusedCountIndexValueRow.setLong(0, updatedCount)
+    store.put(elementKey,
+      
countIndexValueProjector(reusedCountIndexValueRow.asInstanceOf[InternalRow]),
+      COUNT_INDEX
+    )
   }
 
   /**
-   * Clears any state which has ttl older than [[ttlExpirationMs]].
+   * Function to remove the number of entries in the list state for a given 
grouping key
+   * @param elementKey - encoded element key
    */
-  override def clearExpiredState(): Long = {
-    val iterator = store.iterator(ttlColumnFamilyName)
-    var numValuesExpired = 0L
+  def removeEntryCount(elementKey: UnsafeRow): Unit = {
+    store.remove(elementKey, COUNT_INDEX)
+  }
 
-    iterator.takeWhile { kv =>
-      val expirationMs = kv.key.getLong(0)
-      StateTTL.isExpired(expirationMs, ttlExpirationMs)
-    }.foreach { kv =>
-      val groupingKey = kv.key.getStruct(1, keyExprEnc.schema.length)
-      numValuesExpired += clearIfExpired(groupingKey)
-      store.remove(kv.key, ttlColumnFamilyName)
+  private def writePrimaryIndexEntries(
+      overwritePrimaryIndex: Boolean,
+      elementKey: UnsafeRow,
+      elementValues: Iterator[UnsafeRow]): Unit = {
+    val initialEntryCount = if (overwritePrimaryIndex) {
+      removeEntryCount(elementKey)
+      0
+    } else {
+      getEntryCount(elementKey)
     }
-    numValuesExpired
-  }
 
-  private[sql] def ttlIndexIterator(): Iterator[SingleKeyTTLRow] = {
-    val ttlIterator = store.iterator(ttlColumnFamilyName)
+    // Manually keep track of the count so that we can update the count index. 
We don't
+    // want to call elementValues.size since that will try to re-read the 
iterator.
+    var numNewElements = 0
+
+    // If we're overwriting the primary index, then we only need to put the 
first value,
+    // and then we can merge the rest.
+    var isFirst = true
+    elementValues.foreach { value =>
+      numNewElements += 1
+      if (isFirst && overwritePrimaryIndex) {
+        isFirst = false
+        store.put(elementKey, value, stateName)
+      } else {
+        store.merge(elementKey, value, stateName)
+      }
+    }
 
-    new Iterator[SingleKeyTTLRow] {
-      override def hasNext: Boolean = ttlIterator.hasNext
+    TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows", 
numNewElements)
+    updateEntryCount(elementKey, initialEntryCount + numNewElements)
+  }
 
-      override def next(): SingleKeyTTLRow = {
-        val kv = ttlIterator.next()
-        SingleKeyTTLRow(
-          expirationMs = kv.key.getLong(0),
-          groupingKey = kv.key.getStruct(1, keyExprEnc.schema.length)
-        )
+  protected def updatePrimaryAndSecondaryIndices(
+      overwritePrimaryIndex: Boolean,
+      elementKey: UnsafeRow,
+      elementValues: Iterator[UnsafeRow],
+      expirationMs: Long): Unit = {
+    val existingMinExpirationUnsafeRow = store.get(elementKey, MIN_INDEX)
+
+    writePrimaryIndexEntries(overwritePrimaryIndex, elementKey, elementValues)
+
+    // If nothing exists in the minimum index, then we need to make sure to 
write
+    // the minimum and the TTL indices. There's nothing to clean-up from the
+    // secondary index, since it's empty.
+    if (existingMinExpirationUnsafeRow == null) {
+      // Insert into the min-expiry and TTL index, in no particular order.
+      store.put(elementKey, minIndexValueProjector(InternalRow(expirationMs)), 
MIN_INDEX)
+      insertIntoTTLIndex(expirationMs, elementKey)
+    } else {
+      val existingMinExpiration = existingMinExpirationUnsafeRow.getLong(0)
+
+      // If we're overwriting the primary index (via a put, not an append), 
then we need
+      // to definitely clear out the secondary index entries. Why? Suppose we 
had expirations
+      // 5, 10, and 15. If we overwrite, then none of those expirations are 
valid anymore.
+      //
+      // If we're not overwriting the primary index, there is still a case 
where we need to
+      // modify the secondary index. This is if the new expiration is less 
than the existing
+      // expiration. In that case, we must delete from the TTL index, and then 
reinsert into
+      // the TTL index, and then overwrite the min index.
+      if (overwritePrimaryIndex || expirationMs < existingMinExpiration) {
+        // We don't actually have to delete from the min index, since we're 
going
+        // to overwrite it on the next line. However, since the TTL index has 
the existing
+        // minimum expiration in it, we need to delete that.
+        deleteFromTTLIndex(existingMinExpiration, elementKey)
+
+        // Insert into the min-expiry and TTL index, in no particular order.
+        store.put(elementKey, 
minIndexValueProjector(InternalRow(expirationMs)), MIN_INDEX)
+        insertIntoTTLIndex(expirationMs, elementKey)
       }
     }
   }
 
-  private[sql] def getValuesInTTLState(groupingKey: UnsafeRow): Iterator[Long] 
= {
-    val ttlIterator = ttlIndexIterator()
-    var nextValue: Option[Long] = None
-
-    new Iterator[Long] {
-      override def hasNext: Boolean = {
-        while (nextValue.isEmpty && ttlIterator.hasNext) {
-          val nextTtlValue = ttlIterator.next()
-          val valueGroupingKey = nextTtlValue.groupingKey
-          if (valueGroupingKey equals groupingKey) {
-            nextValue = Some(nextTtlValue.expirationMs)
-          }
-        }
-        nextValue.isDefined
-      }
+  // The return type of clearExpiredValues. For a one-to-many stateful 
variable, cleanup
+  // must go through all of the values. numValuesExpired represents the number 
of entries
+  // that were removed (for metrics), and newMinExpirationMs is the new 
minimum expiration
+  // for the values remaining in the state variable.
+  case class ValueExpirationResult(
+      numValuesExpired: Long,
+      newMinExpirationMs: Option[Long])
+
+  // Clears all the expired values for the given elementKey.
+  protected def clearExpiredValues(elementKey: UnsafeRow): 
ValueExpirationResult
 
-      override def next(): Long = {
-        val result = nextValue.get
-        nextValue = None
-        result
+  override def clearExpiredStateForAllKeys(): Long = {
+    var totalNumValuesExpired = 0L
+
+    ttlEvictionIterator().foreach { ttlKey =>
+      val ttlRow = toTTLRow(ttlKey)
+      val elementKey = ttlRow.elementKey
+
+      // Delete from TTL index and minimum index
+      deleteFromTTLIndex(ttlKey)
+      store.remove(elementKey, MIN_INDEX)
+
+      // Now, we need the specific implementation to remove all the expired values associated with
+      // elementKey.
+      val valueExpirationResult = clearExpiredValues(elementKey)
+
+      valueExpirationResult.newMinExpirationMs.foreach { newExpirationMs =>
+        // Insert into the min-expiry and TTL index, in no particular order.
+        store.put(elementKey, 
minIndexValueProjector(InternalRow(newExpirationMs)), MIN_INDEX)
+        insertIntoTTLIndex(newExpirationMs, elementKey)
       }
-    }
-  }
 
-  /**
-   * Clears the user state associated with this grouping key
-   * if it has expired. This function is called by Spark to perform
-   * cleanup at the end of transformWithState processing.
-   *
-   * Spark uses a secondary index to determine if the user state for
-   * this grouping key has expired. However, its possible that the user
-   * has updated the TTL and secondary index is out of date. Implementations
-   * must validate that the user State has actually expired before cleanup 
based
-   * on their own State data.
-   *
-   * @param groupingKey grouping key for which cleanup should be performed.
-   *
-   * @return true if the state was cleared, false otherwise.
-   */
-  def clearIfExpired(groupingKey: UnsafeRow): Long
-}
+      // If we have records [foo, bar, baz] and bar and baz are expiring, 
then, the
+      // entryCountBeforeExpirations would be 3. The numValuesExpired would be 
2, and so the
+      // newEntryCount would be 3 - 2 = 1.
+      val entryCountBeforeExpirations = getEntryCount(elementKey)
+      val numValuesExpired = valueExpirationResult.numValuesExpired
+      val newEntryCount = entryCountBeforeExpirations - numValuesExpired
 
-/**
- * Manages the ttl information for user state keyed with a single key 
(grouping key).
- */
-abstract class CompositeKeyTTLStateImpl[K](
-    stateName: String,
-    store: StateStore,
-    keyExprEnc: ExpressionEncoder[Any],
-    userKeyEncoder: ExpressionEncoder[Any],
-    ttlExpirationMs: Long)
-  extends TTLState {
-
-  import org.apache.spark.sql.execution.streaming.StateTTLSchema._
-
-  private val ttlColumnFamilyName = "$ttl_" + stateName
-  private val keySchema = getCompositeKeyTTLRowSchema(
-    keyExprEnc.schema, userKeyEncoder.schema
-  )
+      TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", 
numValuesExpired)
 
-  private val keyRowEncoder = new CompositeKeyTTLEncoder[K](
-    keyExprEnc, userKeyEncoder)
+      if (newEntryCount == 0) {
+        removeEntryCount(elementKey)
+      } else {
+        updateEntryCount(elementKey, newEntryCount)
+      }
 
-  // empty row used for values
-  private val EMPTY_ROW =
-    
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+      totalNumValuesExpired += numValuesExpired
+    }
+
+    totalNumValuesExpired
+  }
+
+  override def clearAllStateForElementKey(elementKey: UnsafeRow): Unit = {
+    val existingMinExpirationUnsafeRow = store.get(elementKey, MIN_INDEX)
+    if (existingMinExpirationUnsafeRow != null) {
+      val existingMinExpiration = existingMinExpirationUnsafeRow.getLong(0)
 
-  store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema,
-    TTL_VALUE_ROW_SCHEMA, RangeKeyScanStateEncoderSpec(keySchema,
-      Seq(0)), isInternal = true)
+      store.remove(elementKey, stateName)
+      TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", 
getEntryCount(elementKey))
+      removeEntryCount(elementKey)
 
-  def clearTTLState(): Unit = {
-    val iterator = store.iterator(ttlColumnFamilyName)
-    iterator.foreach { kv =>
-      store.remove(kv.key, ttlColumnFamilyName)
+      store.remove(elementKey, MIN_INDEX)
+      deleteFromTTLIndex(existingMinExpiration, elementKey)
     }
   }
 
-  def upsertTTLForStateKey(
-      expirationMs: Long,
-      groupingKey: UnsafeRow,
-      userKey: UnsafeRow): Unit = {
-    val encodedTtlKey = keyRowEncoder.encodeTTLRow(
-      expirationMs, groupingKey, userKey)
-    store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName)
+  // Exposed for testing.
+  private[sql] def minIndexIterator(): Iterator[(UnsafeRow, Long)] = {
+    store
+      .iterator(MIN_INDEX)
+      .map(kv => (kv.key, kv.value.getLong(0)))
   }
+}
+
+
+/**
+ * OneToOneTTLState is an implementation of [[TTLState]] that is used to manage
+ * TTL for state variables that need a single secondary index to efficiently 
manage
+ * records with an expiration.
+ *
+ * The primary index for state variables that can use a [[OneToOneTTLState]] 
has
+ * the form of: [elementKey -> (value, elementExpiration)]. You'll notice 
that, given
+ * a timestamp, it would take linear time to probe the primary index for all 
of its
+ * expired values.
+ *
+ * As a result, this class uses helper methods from [[TTLState]] to maintain 
the secondary
+ * index from [(elementExpiration, elementKey) -> EMPTY_ROW].
+ *
+ * For an explanation of why this structure is not always sufficient (e.g. why 
the class
+ * [[OneToManyTTLState]] is needed), please visit its class-doc comment.
+ */
+abstract class OneToOneTTLState(
+    stateNameArg: String,
+    storeArg: StateStore,
+    elementKeySchemaArg: StructType,
+    ttlConfigArg: TTLConfig,
+    batchTimestampMsArg: Long,
+    metricsArg: Map[String, SQLMetric]) extends TTLState {
+  override def stateName: String = stateNameArg
+  override def store: StateStore = storeArg
+  override def elementKeySchema: StructType = elementKeySchemaArg
+  override def ttlConfig: TTLConfig = ttlConfigArg
+  override def batchTimestampMs: Long = batchTimestampMsArg
+  override def metrics: Map[String, SQLMetric] = metricsArg
 
   /**
-   * Clears any state which has ttl older than [[ttlExpirationMs]].
+   * This method updates the TTL for the given elementKey to be expirationMs,
+   * updating both the primary and secondary indices if needed.
+   *
+   * Note that an elementKey may be the state variable's grouping key, _or_ it
+   * could be a composite key. MapState is an example of a state variable that
+   * has composite keys, which has the structure of the groupingKey followed by
+   * the specific key in the map. This method doesn't need to know what type of
+   * key is being used, though, since in either case, it's just an UnsafeRow.
+   *
+   * @param elementKey the key for which the TTL should be updated, which may
+   *                   either be the grouping key or a composite key (e.g. the
+   *                   grouping key followed by a map key).
+   * @param expirationMs the new expiration timestamp to use for elementKey.

Review Comment:
   Ok, I will keep the existing behavior of having the trait manage both
indices. I've renamed the method. If we want to do a refactor later
(that doesn't change the format), then we can.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to