HeartSaVioR commented on code in PR #48853:
URL: https://github.com/apache/spark/pull/48853#discussion_r1857469364
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##########
@@ -19,274 +19,575 @@ package org.apache.spark.sql.execution.streaming
import java.time.Duration
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
-import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.streaming.TTLConfig
import org.apache.spark.sql.types._
-object StateTTLSchema {
- val TTL_VALUE_ROW_SCHEMA: StructType =
- StructType(Array(StructField("__dummy__", NullType)))
-}
-
/**
- * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]].
+ * Any state variable that wants to support TTL must implement this trait,
+ * which they can do by extending [[OneToOneTTLState]] or [[OneToManyTTLState]].
*
- * @param groupingKey grouping key for which ttl is set
- * @param expirationMs expiration time for the grouping key
- */
-case class SingleKeyTTLRow(
- groupingKey: UnsafeRow,
- expirationMs: Long)
-
-/**
- * Encapsulates the ttl row information stored in [[CompositeKeyTTLStateImpl]].
+ * The only required methods here are ones relating to evicting expired and all
+ * state, via clearExpiredStateForAllKeys and clearAllStateForElementKey,
+ * respectively. How classes do this is implementation detail, but the general
+ * pattern is to use secondary indexes to make sure cleanup scans
+ * theta(records to evict), not theta(all records).
*
- * @param groupingKey grouping key for which ttl is set
- * @param userKey user key for which ttl is set
- * @param expirationMs expiration time for the grouping key
- */
-case class CompositeKeyTTLRow(
- groupingKey: UnsafeRow,
- userKey: UnsafeRow,
- expirationMs: Long)
-
-/**
- * Represents the underlying state for secondary TTL Index for a user defined
- * state variable.
+ * There are two broad patterns of implementing stateful variables, and thus
+ * there are two broad patterns for implementing TTL. The first is when there
+ * is a one-to-one mapping between an element key [1] and a value; the primary
+ * and secondary index management for this case is implemented by
+ * [[OneToOneTTLState]]. When a single element key can have multiple values,
+ * all of which can expire at their own, unique times, then
+ * [[OneToManyTTLState]] should be used.
+ *
+ * In either case, implementations need to use some sort of secondary index
+ * that orders element keys by expiration time. This base functionality
+ * is provided by methods in this trait that read/write/delete to the
+ * so-called "TTL index". It is a secondary index with the layout of
+ * (expirationMs, elementKey) -> EMPTY_ROW. The expirationMs is big-endian
+ * encoded to allow for efficient range scans to find all expired keys.
+ *
+ * TTLState (or any abstract sub-classes) should never deal with encoding or
+ * decoding UnsafeRows to and from their user-facing types. The stateful variable
+ * themselves should be doing this; all other TTLState sub-classes should be concerned
+ * only with writing, reading, and deleting UnsafeRows and their associated
+ * expirations from the primary and secondary indexes. [2]
+ *
+ * [1]. You might ask, why call it "element key" instead of "grouping key"?
+ * This is because a single grouping key might have multiple elements, as in
+ * the case of a map, which has composite keys of the form (groupingKey, mapKey).
+ * In the case of ValueState, though, the element key is the grouping key.
+ * To generalize to both cases, this class should always use the term elementKey.)
*
- * This state allows Spark to query ttl values based on expiration time
- * allowing efficient ttl cleanup.
+ * [2]. You might also ask, why design it this way? We want the TTLState abstract
+ * sub-classes to write to both the primary and secondary indexes, since they
+ * both need to stay in sync; co-locating the logic is cleanest.
*/
trait TTLState {
+ // Name of the state variable, e.g. the string the user passes to get{Value/List/Map}State
+ // in the init() method of a StatefulProcessor.
+ def stateName: String
+
+ // The StateStore instance used to store the state. There is only one instance shared
+ // among the primary and secondary indexes, since it uses virtual column families
+ // to keep the indexes separate.
+ def store: StateStore
+
+ // The schema of the primary key for the state variable. For value and list state, this
+ // is the grouping key. For map state, this is the composite key of the grouping key and
+ // a map key.
+ def elementKeySchema: StructType
+
+ // The timestamp at which the batch is being processed. All state variables that have
+ // an expiration at or before this timestamp must be cleaned up.
+ def batchTimestampMs: Long
+
+ // The configuration for this run of the streaming query. It may change between runs
+ // (e.g. user sets ttlConfig1, stops their query, updates to ttlConfig2, and then
+ // resumes their query).
+ def ttlConfig: TTLConfig
+
+ // A map from metric name to the underlying SQLMetric. This should not be updated
+ // by the underlying state variable, as the TTL state implementation should be
+ // handling all reads/writes/deletes to the indexes.
+ def metrics: Map[String, SQLMetric] = Map.empty
+
+ private final val TTL_INDEX = "$ttl_" + stateName
+ private final val TTL_INDEX_KEY_SCHEMA = getTTLRowKeySchema(elementKeySchema)
+ private final val TTL_EMPTY_VALUE_ROW_SCHEMA: StructType =
+ StructType(Array(StructField("__empty__", NullType)))
+
+ private final val TTL_ENCODER = new TTLEncoder(elementKeySchema)
+
+ // Empty row used for values
+ private final val TTL_EMPTY_VALUE_ROW =
+ UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+ protected final def ttlExpirationMs = StateTTL
+ .calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+ store.createColFamilyIfAbsent(
+ TTL_INDEX,
+ TTL_INDEX_KEY_SCHEMA,
+ TTL_EMPTY_VALUE_ROW_SCHEMA,
+ RangeKeyScanStateEncoderSpec(TTL_INDEX_KEY_SCHEMA, Seq(0)),
+ isInternal = true
+ )
+
+ protected def insertIntoTTLIndex(expirationMs: Long, elementKey: UnsafeRow): Unit = {
+ val secondaryIndexKey = TTL_ENCODER.encodeTTLRow(expirationMs, elementKey)
+ store.put(secondaryIndexKey, TTL_EMPTY_VALUE_ROW, TTL_INDEX)
+ }
+
+ // The deleteFromTTLIndex overload that takes an expiration time and elementKey as an
+ // argument is used when we need to _construct_ the key to delete from the TTL index.
+ //
+ // If we know the timestamp to delete and the elementKey, but don't have a pre-constructed
+ // UnsafeRow, then you should use this method to delete from the TTL index.
+ protected def deleteFromTTLIndex(expirationMs: Long, elementKey: UnsafeRow): Unit = {
+ val secondaryIndexKey = TTL_ENCODER.encodeTTLRow(expirationMs, elementKey)
+ store.remove(secondaryIndexKey, TTL_INDEX)
+ }
+
+ // The deleteFromTTLIndex overload that takes an UnsafeRow as an argument is used when
+ // we're deleting elements from the TTL index that we are iterating over.
+ //
+ // If we were to use the other deleteFromTTLIndex method, we would have to re-encode the
+ // components into an UnsafeRow. It is more efficient to just pass the UnsafeRow that we
+ // read from the iterator.
+ protected def deleteFromTTLIndex(ttlKey: UnsafeRow): Unit = {
+ store.remove(ttlKey, TTL_INDEX)
+ }
+
+ private[sql] def toTTLRow(ttlKey: UnsafeRow): TTLRow = {
+ val expirationMs = ttlKey.getLong(0)
+ val elementKey = ttlKey.getStruct(1, TTL_INDEX_KEY_SCHEMA.length)
+ TTLRow(elementKey, expirationMs)
+ }
+
+ private[sql] def getTTLRows(): Iterator[TTLRow] = {
+ store.iterator(TTL_INDEX).map(kv => toTTLRow(kv.key))
+ }
+
+ // Returns an Iterator over all the keys in the TTL index that have expired. This method
+ // does not delete the keys from the TTL index; it is the responsibility of the caller
+ // to do so.
+ //
+ // The schema of the UnsafeRow returned by this iterator is (expirationMs, elementKey).
+ protected def ttlEvictionIterator(): Iterator[UnsafeRow] = {
+ val ttlIterator = store.iterator(TTL_INDEX)
+
+ // Recall that the format is (expirationMs, elementKey) -> TTL_EMPTY_VALUE_ROW, so
+ // kv.value doesn't ever need to be used.
+ ttlIterator.takeWhile(kv => {
+ val expirationMs = kv.key.getLong(0)
+ StateTTL.isExpired(expirationMs, batchTimestampMs)
+ }).map(_.key)
+ }
+
+
+ // Encapsulates a row stored in a TTL index.
+ protected case class TTLRow(elementKey: UnsafeRow, expirationMs: Long)
/**
- * Perform the user state clean up based on ttl values stored in
- * this state. NOTE that its not safe to call this operation concurrently
- * when the user can also modify the underlying State. Cleanup should be initiated
- * after arbitrary state operations are completed by the user.
+ * Evicts the state associated with this stateful variable that has expired
+ * due to TTL. The eviction applies to all grouping keys, and to all indexes,
+ * primary or secondary.
+ *
+ * This method can be called at any time in the micro-batch execution,
+ * as long as it is allowed to complete before subsequent state operations are
+ * issued. Operations to the state variable should not be issued concurrently while
+ * this is running.
+ *
+ * (Why? Some cleanup operations leave the state store in an inconsistent state while
+ * they are doing cleanup. For example, they may choose to get an iterator over all
+ * of the existing values, delete all values, and then re-insert only the non-expired
+ * values from the iterator. If a get is issued after the delete happens but before the
+ * re-insertion completes, the get could return null even when the value does actually
+ * exist.)
*
* @return number of values cleaned up.
*/
- def clearExpiredState(): Long
+ def clearExpiredStateForAllKeys(): Long
+
+ /**
+ * When a user calls clear() on a stateful variable, this method is invoked to
Review Comment:
Honestly (and I admit I'm bad at this too), this sounds like a method-naming
issue rather than a documentation issue.
I think this trait causes unnecessary confusion because it goes beyond
handling TTL: it's not a pure mix-in that attaches TTL capability, but takes on
more responsibility, e.g. removing entries from the primary index in
clearAllStateForElementKey. You could move that responsibility to the derived
classes (we already do that for updating data) and deal only with TTL metadata
in this trait. After that, the name would no longer need to contain `State`,
which is an overloaded term.
I don't have a strong opinion on whether we should change the code comment;
I've said it's a nit, so just consider this my 2 cents. But I would argue more
strongly if this were a user-facing API.
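
A minimal sketch of the split suggested above, assuming plain in-memory
collections in place of the StateStore and String keys in place of UnsafeRows.
`TTLIndex`, `ValueStateWithTTL`, and `expiredEntries` are hypothetical names for
illustration only, not code from this PR: the mix-in touches only the
(expirationMs, elementKey) index, and the derived class owns the primary index.

```scala
import scala.collection.mutable

// Hypothetical mix-in: owns only the TTL metadata, i.e. a set of
// (expirationMs, elementKey) pairs sorted by expiration time first.
trait TTLIndex {
  private val ttlIndex = mutable.TreeSet.empty[(Long, String)]

  protected def insertIntoTTLIndex(expirationMs: Long, elementKey: String): Unit =
    ttlIndex += ((expirationMs, elementKey))

  protected def deleteFromTTLIndex(expirationMs: Long, elementKey: String): Unit =
    ttlIndex -= ((expirationMs, elementKey))

  // Entries expiring at or before batchTimestampMs. The result is materialized
  // so the caller can delete from the index while walking it.
  protected def expiredEntries(batchTimestampMs: Long): List[(Long, String)] =
    ttlIndex.iterator.takeWhile(_._1 <= batchTimestampMs).toList
}

// Hypothetical state variable: owns the primary index and decides when the
// TTL index has to be updated to stay in sync with it.
class ValueStateWithTTL(ttlMs: Long) extends TTLIndex {
  // elementKey -> (value, expirationMs)
  private val primary = mutable.Map.empty[String, (String, Long)]

  def update(key: String, value: String, nowMs: Long): Unit = {
    // Drop the stale TTL entry, if any, before writing the new one.
    primary.get(key).foreach { case (_, oldExp) => deleteFromTTLIndex(oldExp, key) }
    val expirationMs = nowMs + ttlMs
    primary(key) = (value, expirationMs)
    insertIntoTTLIndex(expirationMs, key)
  }

  def get(key: String): Option[String] = primary.get(key).map(_._1)

  // clear() removes from the primary index here, not in the mix-in.
  def clear(key: String): Unit =
    primary.remove(key).foreach { case (_, exp) => deleteFromTTLIndex(exp, key) }

  // Returns the number of values cleaned up.
  def clearExpired(batchTimestampMs: Long): Long = {
    val expired = expiredEntries(batchTimestampMs)
    expired.foreach { case (exp, key) =>
      primary.remove(key)
      deleteFromTTLIndex(exp, key)
    }
    expired.size.toLong
  }
}
```

With this shape the mix-in never sees the primary index, at the cost of every
derived class having to keep the two indexes in sync itself, which is the
trade-off footnote [2] in the doc comment argues against.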
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]