neilramaswamy commented on code in PR #48853:
URL: https://github.com/apache/spark/pull/48853#discussion_r1847381368
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##########
@@ -19,274 +19,553 @@ package org.apache.spark.sql.execution.streaming
import java.time.Duration
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow,
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import
org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec,
StateStore}
+import
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec,
RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.streaming.TTLConfig
import org.apache.spark.sql.types._
-object StateTTLSchema {
- val TTL_VALUE_ROW_SCHEMA: StructType =
- StructType(Array(StructField("__dummy__", NullType)))
-}
-
/**
- * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]].
+ * Any state variable that wants to support TTL must implement this trait,
+ * which they can do by extending [[OneToOneTTLState]] or
[[OneToManyTTLState]].
*
- * @param groupingKey grouping key for which ttl is set
- * @param expirationMs expiration time for the grouping key
- */
-case class SingleKeyTTLRow(
- groupingKey: UnsafeRow,
- expirationMs: Long)
-
-/**
- * Encapsulates the ttl row information stored in [[CompositeKeyTTLStateImpl]].
+ * The only required methods here are ones relating to evicting expired and all
+ * state, via clearExpiredStateForAllKeys and clearAllStateForElementKey,
+ * respectively. How classes do this is implementation detail, but the general
+ * pattern is to use secondary indexes to make sure cleanup scans
+ * theta(records to evict), not theta(all records).
*
- * @param groupingKey grouping key for which ttl is set
- * @param userKey user key for which ttl is set
- * @param expirationMs expiration time for the grouping key
- */
-case class CompositeKeyTTLRow(
- groupingKey: UnsafeRow,
- userKey: UnsafeRow,
- expirationMs: Long)
-
-/**
- * Represents the underlying state for secondary TTL Index for a user defined
- * state variable.
+ * There are two broad patterns of implementing stateful variables, and thus
+ * there are two broad patterns for implementing TTL. The first is when there
+ * is a one-to-one mapping between an element key [1] and a value; the primary
+ * and secondary index management for this case is implemented by
+ * [[OneToOneTTLState]]. When a single element key can have multiple values,
+ * all of which can expire at their own, unique times, then
+ * [[OneToManyTTLState]] should be used.
+ *
+ * In either case, implementations need to use some sort of secondary index
+ * that orders element keys by expiration time. This base functionality
+ * is provided by methods in this trait that read/write/delete to the
+ * so-called "TTL index". It is a secondary index with the layout of
+ * (expirationMs, elementKey) -> EMPTY_ROW. The expirationMs is big-endian
+ * encoded to allow for efficient range scans to find all expired keys.
+ *
+ * TTLState (or any abstract sub-classes) should never deal with encoding or
+ * decoding UnsafeRows to and from their user-facing types. The stateful
variables
+ * themselves should be doing this; all other TTLState sub-classes should be
concerned
+ * only with writing, reading, and deleting UnsafeRows and their associated
+ * expirations from the primary and secondary indexes. [2]
*
- * This state allows Spark to query ttl values based on expiration time
- * allowing efficient ttl cleanup.
+ * [1]. You might ask, why call it "element key" instead of "grouping key"?
+ * This is because a single grouping key might have multiple elements, as
in
+ * the case of a map, which has composite keys of the form (groupingKey,
mapKey).
+ * In the case of ValueState, though, the element key is the grouping key.
+ * To generalize to both cases, this class should always use the term
elementKey.
+ *
+ * [2]. You might also ask, why design it this way? We want the TTLState
abstract
+ * sub-classes to write to both the primary and secondary indexes, since
they
+ * both need to stay in sync; co-locating the logic is cleanest.
*/
trait TTLState {
+ // Name of the state variable, e.g. the string the user passes to
get{Value/List/Map}State
+ // in the init() method of a StatefulProcessor.
+ def stateName: String
+
+ // The StateStore instance used to store the state. There is only one
instance shared
+ // among the primary and secondary indexes, since it uses virtual column
families
+ // to keep the indexes separate.
+ def store: StateStore
+
+ // The schema of the primary key for the state variable. For value and list
state, this
+ // is the grouping key. For map state, this is the composite key of the
grouping key and
+ // a map key.
+ def elementKeySchema: StructType
+
+ // The timestamp at which the batch is being processed. All state variables
that have
+ // an expiration at or before this timestamp must be cleaned up.
+ def batchTimestampMs: Long
+
+ // The configuration for this run of the streaming query. It may change
between runs
+ // (e.g. user sets ttlConfig1, stops their query, updates to ttlConfig2, and
then
+ // resumes their query).
+ def ttlConfig: TTLConfig
+
+ // A map from metric name to the underlying SQLMetric. This should not be
updated
+ // by the underlying state variable, as the TTL state implementation should
be
+ // handling all reads/writes/deletes to the indexes.
+ def metrics: Map[String, SQLMetric] = Map.empty
+
+ private final val TTL_INDEX = "$ttl_" + stateName
+ private final val TTL_INDEX_KEY_SCHEMA =
getSingleKeyTTLRowSchema(elementKeySchema)
+ private final val TTL_EMPTY_VALUE_ROW_SCHEMA: StructType =
+ StructType(Array(StructField("__empty__", NullType)))
+
+ private final val TTL_ENCODER = new TTLEncoder(elementKeySchema)
+
+ // Empty row used for values
+ private final val TTL_EMPTY_VALUE_ROW =
+
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+ protected final def ttlExpirationMs = StateTTL
+ .calculateExpirationTimeForDuration(ttlConfig.ttlDuration,
batchTimestampMs)
+
+ store.createColFamilyIfAbsent(
+ TTL_INDEX,
+ TTL_INDEX_KEY_SCHEMA,
+ TTL_EMPTY_VALUE_ROW_SCHEMA,
+ RangeKeyScanStateEncoderSpec(TTL_INDEX_KEY_SCHEMA, Seq(0)),
+ isInternal = true
+ )
+
+ protected def insertIntoTTLIndex(expirationMs: Long, elementKey: UnsafeRow):
Unit = {
+ val secondaryIndexKey = TTL_ENCODER.encodeTTLRow(expirationMs, elementKey)
+ store.put(secondaryIndexKey, TTL_EMPTY_VALUE_ROW, TTL_INDEX)
+ }
+
+ protected def deleteFromTTLIndex(expirationMs: Long, elementKey: UnsafeRow):
Unit = {
+ val secondaryIndexKey = TTL_ENCODER.encodeTTLRow(expirationMs, elementKey)
+ store.remove(secondaryIndexKey, TTL_INDEX)
+ }
+
+ private[sql] def toTTLRow(ttlKey: UnsafeRow): TTLRow = {
+ val expirationMs = ttlKey.getLong(0)
+ val elementKey = ttlKey.getStruct(1, TTL_INDEX_KEY_SCHEMA.length)
+ TTLRow(elementKey, expirationMs)
+ }
+
+ private[sql] def getTTLRows(): Iterator[TTLRow] = {
+ store.iterator(TTL_INDEX).map(kv => toTTLRow(kv.key))
+ }
+
+ protected def ttlEvictionIterator(): Iterator[TTLRow] = {
+ val ttlIterator = store.iterator(TTL_INDEX)
+
+ // Recall that the format is (expirationMs, elementKey) ->
TTL_EMPTY_VALUE_ROW, so
+ // kv.value doesn't ever need to be used.
+ ttlIterator.takeWhile(kv => {
+ val expirationMs = kv.key.getLong(0)
+ StateTTL.isExpired(expirationMs, batchTimestampMs)
+ }).map { kv =>
+ store.remove(kv.key, TTL_INDEX)
+ toTTLRow(kv.key)
+ }
+ }
+
+
+ // Encapsulates a row stored in a TTL index.
+ protected case class TTLRow(elementKey: UnsafeRow, expirationMs: Long)
/**
- * Perform the user state clean up based on ttl values stored in
- * this state. NOTE that its not safe to call this operation concurrently
- * when the user can also modify the underlying State. Cleanup should be
initiated
- * after arbitrary state operations are completed by the user.
+ * Evicts the state associated with this stateful variable that has expired
+ * due to TTL. The eviction applies to all grouping keys, and to all indexes,
+ * primary or secondary.
+ *
+ * This method can be called at any time in the micro-batch execution,
+ * as long as it is allowed to complete before subsequent state operations
are
+ * issued. Operations to the state variable should not be issued
concurrently while
+ * this is running.
+ *
+ * (Why? Some cleanup operations leave the state store in an inconsistent
state while
+ * they are doing cleanup. For example, they may choose to get an iterator
over all
+ * of the existing values, delete all values, and then re-insert only the
non-expired
+ * values from the iterator. If a get is issued after the delete happens but
before the
+ * re-insertion completes, the get could return null even when the value
does actually
+ * exist.)
*
* @return number of values cleaned up.
*/
- def clearExpiredState(): Long
+ def clearExpiredStateForAllKeys(): Long
+
+ /**
+ * Clears all of the state for this state variable associated with the
primary key
+ * elementKey. It is responsible for deleting from the primary index as well
as
+ * any secondary index(es).
+ *
+ * If a given state variable has to clean up multiple elementKeys (in
MapState, for
+ * example, every key in the map is its own elementKey), then this method
should
+ * be invoked for each of those keys.
+ */
+ def clearAllStateForElementKey(elementKey: UnsafeRow): Unit
}
/**
- * Manages the ttl information for user state keyed with a single key
(grouping key).
+ * [[OneToManyTTLState]] is an implementation of [[TTLState]] for stateful
variables
+ * that associate a single key with multiple values; every value has its own
expiration
+ * timestamp.
+ *
+ * We need an efficient way to find all the values that have expired, but we
cannot
+ * issue point-wise deletes to the elements, since they are merged together
using the
+ * RocksDB StringAppendOperator for merging. As such, we cannot keep a
secondary index
+ * on the key (expirationMs, groupingKey, indexInList), since we have no way
to delete a
+ * specific indexInList from the RocksDB value. (In the future, we could write
a custom
+ * merge operator that can handle tombstones for deleted indexes, but RocksDB
doesn't
+ * support custom merge operators written in Java/Scala.)
+ *
+ * Instead, we manage expiration per grouping key. Our secondary index
will look
+ * like (expirationMs, groupingKey) -> EMPTY_ROW. This way, we can quickly
find all the
+ * grouping keys that contain at least one element that has expired.
+ *
+ * There is some trickiness here, though. Suppose we have an element key `k`
that
+ * has a list with one value `v1` that expires at time `t1`. Our primary index
looks like
+ * k -> [v1]; our secondary index looks like [(t1, k) -> EMPTY_ROW]. Now, we
add another
+ * value to the list, `v2`, that expires at time `t2`. The primary index
updates to be
+ * k -> [v1, v2]. However, how do we update our secondary index? We already
have an entry
+ * in our secondary index for `k`, but it's prefixed with `t1`, which we don't
know at the
+ * time of inserting `v2`.
+ *
+ * So, do we:
+ * 1. Blindly add (t2, k) -> EMPTY_ROW to the secondary index?
+ * 2. Delete (t1, k) from the secondary index, and then add (t2, k) ->
EMPTY_ROW?
+ *
+ * We prefer option 2 because it prevents us from having many entries in the
secondary
+ * index for the same grouping key. But when we have (t2, k), how do we know
to delete
+ * (t1, k)? How do we know that t1 is prefixing k in the secondary index?
+ *
+ * To solve this, we introduce another index that maps element key to the
+ * minimum expiry timestamp for that element key. It is called the min-expiry
index.
+ * With it, we can know how to update our secondary index, and we can still
range-scan
+ * the TTL index to find all the keys that have expired.
+ *
+ * In our previous example, we'd use this min-expiry index to tell us whether
`t2` was
+ * smaller than the minimum expiration on-file for `k`. If it were smaller,
we'd update
+ * the TTL index, and the min-expiry index. If not, then we'd just update the
primary
+ * index. When the batchTimestampMs exceeded `t1`, we'd know to clean up `k`
since it would
+ * contain at least one expired value. When iterating through the many values
for this `k`,
+ * we'd then be able to find the next minimum value to insert back into the
secondary and
+ * min-expiry index.
+ *
+ * All of this logic is implemented by updatePrimaryAndSecondaryIndices.
*/
-abstract class SingleKeyTTLStateImpl(
- stateName: String,
- store: StateStore,
- keyExprEnc: ExpressionEncoder[Any],
- ttlExpirationMs: Long)
- extends TTLState {
-
- import org.apache.spark.sql.execution.streaming.StateTTLSchema._
-
- private val ttlColumnFamilyName = "$ttl_" + stateName
- private val keySchema = getSingleKeyTTLRowSchema(keyExprEnc.schema)
- private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc)
+abstract class OneToManyTTLState(
+ stateNameArg: String,
+ storeArg: StateStore,
+ elementKeySchemaArg: StructType,
+ ttlConfigArg: TTLConfig,
+ batchTimestampMsArg: Long,
+ metricsArg: Map[String, SQLMetric]) extends TTLState {
+ override def stateName: String = stateNameArg
+ override def store: StateStore = storeArg
+ override def elementKeySchema: StructType = elementKeySchemaArg
+ override def ttlConfig: TTLConfig = ttlConfigArg
+ override def batchTimestampMs: Long = batchTimestampMsArg
+ override def metrics: Map[String, SQLMetric] = metricsArg
+
+ // Schema of the min index: elementKey -> minExpirationMs
+ private val MIN_INDEX = "$min_" + stateName
+ private val MIN_INDEX_SCHEMA = elementKeySchema
+ private val MIN_INDEX_VALUE_SCHEMA = getExpirationMsRowSchema()
+
+ // Projects a Long into an UnsafeRow
+ private val minIndexValueProjector =
UnsafeProjection.create(MIN_INDEX_VALUE_SCHEMA)
+
+ // Schema of the entry count index: elementKey -> count
+ private val COUNT_INDEX = "$count_" + stateName
+ private val COUNT_INDEX_VALUE_SCHEMA: StructType =
+ StructType(Seq(StructField("count", LongType, nullable = false)))
+ private val countIndexValueProjector =
UnsafeProjection.create(COUNT_INDEX_VALUE_SCHEMA)
+
+ // Reused internal row that we use to create an UnsafeRow with the schema of
+ // COUNT_INDEX_VALUE_SCHEMA and the desired value. It is not thread safe
(although, anyway,
+ // this class is not thread safe).
+ private val reusedCountIndexValueRow = new GenericInternalRow(1)
+
+ store.createColFamilyIfAbsent(
+ MIN_INDEX,
+ MIN_INDEX_SCHEMA,
+ MIN_INDEX_VALUE_SCHEMA,
+ NoPrefixKeyStateEncoderSpec(MIN_INDEX_SCHEMA),
+ isInternal = true
+ )
- // empty row used for values
- private val EMPTY_ROW =
-
UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+ store.createColFamilyIfAbsent(
+ COUNT_INDEX,
+ elementKeySchema,
+ COUNT_INDEX_VALUE_SCHEMA,
+ NoPrefixKeyStateEncoderSpec(elementKeySchema),
+ isInternal = true
+ )
- store.createColFamilyIfAbsent(ttlColumnFamilyName, keySchema,
TTL_VALUE_ROW_SCHEMA,
- RangeKeyScanStateEncoderSpec(keySchema, Seq(0)), isInternal = true)
/**
- * This function will be called when clear() on State Variables
- * with ttl enabled is called. This function should clear any
- * associated ttlState, since we are clearing the user state.
+ * Function to get the number of entries in the list state for a given
grouping key
+ * @param elementKey - encoded element key
+ * @return - number of entries in the list state
*/
- def clearTTLState(): Unit = {
- val iterator = store.iterator(ttlColumnFamilyName)
- iterator.foreach { kv =>
- store.remove(kv.key, ttlColumnFamilyName)
+ def getEntryCount(elementKey: UnsafeRow): Long = {
+ val countRow = store.get(elementKey, COUNT_INDEX)
+ if (countRow != null) {
+ countRow.getLong(0)
+ } else {
+ 0L
}
}
- def upsertTTLForStateKey(
- expirationMs: Long,
- groupingKey: UnsafeRow): Unit = {
- val encodedTtlKey = keyTTLRowEncoder.encodeTTLRow(
- expirationMs, groupingKey)
- store.put(encodedTtlKey, EMPTY_ROW, ttlColumnFamilyName)
+ /**
+ * Function to update the number of entries in the list state for a given
element key
+ * @param elementKey - encoded grouping key
+ * @param updatedCount - updated count of entries in the list state
+ */
+ def updateEntryCount(elementKey: UnsafeRow, updatedCount: Long): Unit = {
+ reusedCountIndexValueRow.setLong(0, updatedCount)
+ store.put(elementKey,
+
countIndexValueProjector(reusedCountIndexValueRow.asInstanceOf[InternalRow]),
+ COUNT_INDEX
+ )
}
/**
- * Clears any state which has ttl older than [[ttlExpirationMs]].
+ * Function to remove the number of entries in the list state for a given
grouping key
+ * @param elementKey - encoded element key
*/
- override def clearExpiredState(): Long = {
- val iterator = store.iterator(ttlColumnFamilyName)
- var numValuesExpired = 0L
+ def removeEntryCount(elementKey: UnsafeRow): Unit = {
+ store.remove(elementKey, COUNT_INDEX)
+ }
- iterator.takeWhile { kv =>
- val expirationMs = kv.key.getLong(0)
- StateTTL.isExpired(expirationMs, ttlExpirationMs)
- }.foreach { kv =>
- val groupingKey = kv.key.getStruct(1, keyExprEnc.schema.length)
- numValuesExpired += clearIfExpired(groupingKey)
- store.remove(kv.key, ttlColumnFamilyName)
+ private def writePrimaryIndexEntries(
+ overwritePrimaryIndex: Boolean,
+ elementKey: UnsafeRow,
+ elementValues: Iterator[UnsafeRow]): Unit = {
+ val initialEntryCount = if (overwritePrimaryIndex) {
+ removeEntryCount(elementKey)
+ 0
+ } else {
+ getEntryCount(elementKey)
}
- numValuesExpired
- }
- private[sql] def ttlIndexIterator(): Iterator[SingleKeyTTLRow] = {
- val ttlIterator = store.iterator(ttlColumnFamilyName)
+ // Manually keep track of the count so that we can update the count index.
We don't
+ // want to call elementValues.size since that will try to re-read the
iterator.
+ var numNewElements = 0
+
+ // If we're overwriting the primary index, then we only need to put the
first value,
+ // and then we can merge the rest.
+ var isFirst = true
+ elementValues.foreach { value =>
+ numNewElements += 1
+ if (isFirst && overwritePrimaryIndex) {
+ isFirst = false
+ store.put(elementKey, value, stateName)
+ } else {
+ store.merge(elementKey, value, stateName)
+ }
+ }
- new Iterator[SingleKeyTTLRow] {
- override def hasNext: Boolean = ttlIterator.hasNext
+ TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows",
numNewElements)
+ updateEntryCount(elementKey, initialEntryCount + numNewElements)
+ }
- override def next(): SingleKeyTTLRow = {
- val kv = ttlIterator.next()
- SingleKeyTTLRow(
- expirationMs = kv.key.getLong(0),
- groupingKey = kv.key.getStruct(1, keyExprEnc.schema.length)
- )
+ protected def updatePrimaryAndSecondaryIndices(
+ overwritePrimaryIndex: Boolean,
+ elementKey: UnsafeRow,
+ elementValues: Iterator[UnsafeRow],
+ expirationMs: Long): Unit = {
+ val existingMinExpirationUnsafeRow = store.get(elementKey, MIN_INDEX)
+
+ writePrimaryIndexEntries(overwritePrimaryIndex, elementKey, elementValues)
+
+ // If nothing exists in the secondary index, then we need to make sure to
write
+ // the primary and the secondary indices. There's nothing to clean-up from
the
+ // secondary index, since it's empty.
Review Comment:
This comment is wrong (thanks, Anish, for catching it): if nothing exists in
the min-expiry index, then this key is new. We must write to the secondary
indices, and there's nothing to clean up since they're empty.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]