HeartSaVioR commented on code in PR #45674:
URL: https://github.com/apache/spark/pull/45674#discussion_r1551021910
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##########
@@ -65,22 +74,48 @@ class StateTypesEncoder[GK, V](
// TODO: validate places that are trying to encode the key and check if we can eliminate /
// add caching for some of these calls.
def encodeGroupingKey(): UnsafeRow = {
+ val keyRow = keyProjection(InternalRow(serializeGroupingKey()))
+ keyRow
+ }
+
+ /**
+ * Encodes the provided grouping key into Spark UnsafeRow.
+ *
+ * @param groupingKeyBytes serialized grouping key byte array
+ * @return encoded UnsafeRow
+ */
+ def encodeSerializedGroupingKey(groupingKeyBytes: Array[Byte]): UnsafeRow = {
+ val keyRow = keyProjection(InternalRow(groupingKeyBytes))
+ keyRow
+ }
+
+ def serializeGroupingKey(): Array[Byte] = {
val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
if (keyOption.isEmpty) {
throw StateStoreErrors.implicitKeyNotFound(stateName)
}
-
val groupingKey = keyOption.get.asInstanceOf[GK]
- val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
- val keyRow = keyProjection(InternalRow(keyByteArr))
- keyRow
+ keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
}
+ /**
+ * Encode the specified value in Spark UnsafeRow with no ttl.
+ * The ttl expiration will be set to -1, specifying no TTL.
+ */
def encodeValue(value: V): UnsafeRow = {
Review Comment:
I'm surprised this is ever possible, given the below method signature.
`def encodeValue(value: V, expirationMs: Long = -1): UnsafeRow`
How are the two methods not ambiguous? This looks like an edge case of the Scala
compiler; otherwise I don't see how it could be accepted in the language spec.
Also, the code comment fits the latter method better. Maybe you forgot to remove
this method and updated it instead?
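For reference, the resolution behavior in question can be reproduced with a minimal, self-contained sketch (names and bodies are made up for illustration): scalac accepts both overloads, and overload resolution prefers the alternative that applies without using default arguments.

```scala
object OverloadDemo {
  // Mirrors the pair of signatures discussed above, with illustrative bodies.
  def encodeValue(value: Int): String = "no-ttl overload"
  def encodeValue(value: Int, expirationMs: Long = -1): String =
    s"ttl overload: $expirationMs"
}
```

Here `OverloadDemo.encodeValue(1)` resolves to the first overload, because alternatives that would need a default argument are dropped when another applicable alternative exists.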
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##########
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType}
+
+object StateTTLSchema {
+ val TTL_KEY_ROW_SCHEMA: StructType = new StructType()
+ .add("expirationMs", LongType)
+ .add("groupingKey", BinaryType)
+ val TTL_VALUE_ROW_SCHEMA: StructType =
+ StructType(Array(StructField("__dummy__", NullType)))
+}
+
+/**
+ * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]].
+ *
+ * @param groupingKey grouping key for which ttl is set
+ * @param expirationMs expiration time for the grouping key
+ */
+case class SingleKeyTTLRow(
+ groupingKey: Array[Byte],
+ expirationMs: Long)
+
+/**
+ * Represents the underlying state for secondary TTL Index for a user defined
+ * state variable.
+ *
+ * This state allows Spark to query ttl values based on expiration time
+ * allowing efficient ttl cleanup.
+ */
+trait TTLState {
+
+ /**
+ * Perform the user state clean up based on ttl values stored in
+ * this state. NOTE that its not safe to call this operation concurrently
+ * when the user can also modify the underlying State. Cleanup should be initiated
+ * after arbitrary state operations are completed by the user.
+ */
+ def clearExpiredState(): Unit
+
+ /**
+ * Clears the user state associated with this grouping key
+ * if it has expired. This function is called by Spark to perform
+ * cleanup at the end of transformWithState processing.
+ *
+ * Spark uses a secondary index to determine if the user state for
+ * this grouping key has expired. However, its possible that the user
+ * has updated the TTL and secondary index is out of date. Implementations
Review Comment:
Do we anticipate a possible bug, or is this expected, e.g. we don't remove the
old entry of the secondary index but just add a new entry when the value is
updated? Yet to read the remaining part of the code.
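The scenario described here (an append-only secondary index, so cleanup must re-validate against primary state) can be modeled with a small hypothetical sketch; the data structures and names below are illustrative, not the PR's actual ones:

```scala
import scala.collection.mutable

object StaleIndexDemo {
  // Primary state: grouping key -> expiration currently in effect.
  val primary = mutable.Map.empty[String, Long]
  // Secondary TTL index of (expirationMs, key) pairs; updates only append,
  // so an old entry for the same key may linger after a TTL update.
  val ttlIndex = mutable.SortedSet.empty[(Long, String)]

  def update(key: String, expirationMs: Long): Unit = {
    primary(key) = expirationMs
    ttlIndex += ((expirationMs, key)) // old entry for `key` is NOT removed
  }

  // Cleanup walks the index but must re-check primary state, because an
  // index entry may be stale after the TTL was updated.
  def clearExpired(nowMs: Long): Unit = {
    ttlIndex.filter(_._1 <= nowMs).foreach { case (exp, key) =>
      ttlIndex -= ((exp, key))
      if (primary.get(key).exists(_ <= nowMs)) primary -= key
    }
  }
}
```

Updating a key's TTL from 10 to 100 and then cleaning at time 50 drops only the stale index entry; the user state survives because the primary row says it has not expired.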
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala:
##########
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{TTLMode, ValueState}
+
+/**
+ * Class that provides a concrete implementation for a single value state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
+ * @param ttlMode - TTL Mode for values stored in this state
+ * @param batchTtlExpirationMs - ttl expiration for the current batch.
+ * @tparam S - data type of object that will be stored
+ */
+class ValueStateImplWithTTL[S](
+ store: StateStore,
+ stateName: String,
+ keyExprEnc: ExpressionEncoder[Any],
+ valEncoder: Encoder[S],
+ ttlMode: TTLMode,
+ batchTtlExpirationMs: Long)
+ extends SingleKeyTTLStateImpl(stateName, store, batchTtlExpirationMs) with ValueState[S] {
+
+ private val keySerializer = keyExprEnc.createSerializer()
+ private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder,
+ stateName, hasTtl = true)
+
+ initialize()
+
+ private def initialize(): Unit = {
+ assert(ttlMode != TTLMode.NoTTL())
+
+ store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+ NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+ }
+
/** Function to check if state exists. Returns true if present and false otherwise */
+ override def exists(): Boolean = {
+ get() != null
+ }
+
+ /** Function to return Option of value if exists and None otherwise */
+ override def getOption(): Option[S] = {
+ Option(get())
+ }
+
/** Function to return associated value with key if exists and null otherwise */
+ override def get(): S = {
+ val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ val resState = stateTypesEncoder.decodeValue(retRow)
+
+ if (!isExpired(retRow)) {
+ resState
+ } else {
+ null.asInstanceOf[S]
+ }
+ } else {
+ null.asInstanceOf[S]
+ }
+ }
+
+ /** Function to update and overwrite state associated with given key */
+ override def update(
+ newState: S,
+ ttlDuration: Duration = Duration.ZERO): Unit = {
+
+ if (ttlMode == TTLMode.EventTimeTTL() && ttlDuration != Duration.ZERO) {
+ throw StateStoreErrors.cannotProvideTTLDurationForEventTimeTTLMode("update", stateName)
+ }
+
+ if (ttlDuration != null && ttlDuration.isNegative) {
+ throw StateStoreErrors.ttlCannotBeNegative("update", stateName)
+ }
+
+ val expirationTimeInMs =
+ if (ttlDuration != null && ttlDuration != Duration.ZERO) {
+ StateTTL.calculateExpirationTimeForDuration(ttlDuration, batchTtlExpirationMs)
+ } else {
+ -1
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ override def update(
+ newState: S,
+ expirationTimeInMs: Long): Unit = {
+
+ if (expirationTimeInMs < 0) {
+ throw StateStoreErrors.ttlCannotBeNegative(
+ "update", stateName)
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ private def doUpdate(newState: S,
+ expirationTimeInMs: Long): Unit = {
+ val encodedValue = stateTypesEncoder.encodeValue(newState, expirationTimeInMs)
+
+ val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+ store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
+ encodedValue, stateName)
+
+ if (expirationTimeInMs != -1) {
+ upsertTTLForStateKey(expirationTimeInMs, serializedGroupingKey)
+ }
+ }
+
+ /** Function to remove state for given key */
+ override def clear(): Unit = {
+ store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+ }
+
+ def clearIfExpired(groupingKey: Array[Byte]): Unit = {
+ val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey)
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ if (isExpired(retRow)) {
+ store.remove(encodedGroupingKey, stateName)
+ }
+ }
+ }
+
+ private def isExpired(valueRow: UnsafeRow): Boolean = {
+ val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow)
+ val isExpired = expirationMs.map(
+ StateTTL.isExpired(_, batchTtlExpirationMs))
+
+ isExpired.isDefined && isExpired.get
+ }
+
+ /*
+ * Internal methods to probe state for testing. The below methods exist for unit tests
+ * to read the state ttl values, and ensure that values are persisted correctly in
+ * the underlying state store.
Review Comment:
nit: 2 spaces are used
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala:
##########
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{TTLMode, ValueState}
+
+/**
+ * Class that provides a concrete implementation for a single value state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
+ * @param ttlMode - TTL Mode for values stored in this state
+ * @param batchTtlExpirationMs - ttl expiration for the current batch.
+ * @tparam S - data type of object that will be stored
+ */
+class ValueStateImplWithTTL[S](
+ store: StateStore,
+ stateName: String,
+ keyExprEnc: ExpressionEncoder[Any],
+ valEncoder: Encoder[S],
+ ttlMode: TTLMode,
+ batchTtlExpirationMs: Long)
+ extends SingleKeyTTLStateImpl(stateName, store, batchTtlExpirationMs) with ValueState[S] {
+
+ private val keySerializer = keyExprEnc.createSerializer()
+ private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder,
+ stateName, hasTtl = true)
+
+ initialize()
+
+ private def initialize(): Unit = {
+ assert(ttlMode != TTLMode.NoTTL())
+
+ store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+ NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+ }
+
/** Function to check if state exists. Returns true if present and false otherwise */
+ override def exists(): Boolean = {
+ get() != null
+ }
+
+ /** Function to return Option of value if exists and None otherwise */
+ override def getOption(): Option[S] = {
+ Option(get())
+ }
+
/** Function to return associated value with key if exists and null otherwise */
+ override def get(): S = {
+ val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ val resState = stateTypesEncoder.decodeValue(retRow)
+
+ if (!isExpired(retRow)) {
+ resState
+ } else {
+ null.asInstanceOf[S]
+ }
+ } else {
+ null.asInstanceOf[S]
+ }
+ }
+
+ /** Function to update and overwrite state associated with given key */
+ override def update(
+ newState: S,
+ ttlDuration: Duration = Duration.ZERO): Unit = {
+
+ if (ttlMode == TTLMode.EventTimeTTL() && ttlDuration != Duration.ZERO) {
+ throw StateStoreErrors.cannotProvideTTLDurationForEventTimeTTLMode("update", stateName)
+ }
+
+ if (ttlDuration != null && ttlDuration.isNegative) {
+ throw StateStoreErrors.ttlCannotBeNegative("update", stateName)
+ }
+
+ val expirationTimeInMs =
+ if (ttlDuration != null && ttlDuration != Duration.ZERO) {
+ StateTTL.calculateExpirationTimeForDuration(ttlDuration, batchTtlExpirationMs)
+ } else {
+ -1
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ override def update(
+ newState: S,
+ expirationTimeInMs: Long): Unit = {
+
+ if (expirationTimeInMs < 0) {
+ throw StateStoreErrors.ttlCannotBeNegative(
+ "update", stateName)
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ private def doUpdate(newState: S,
+ expirationTimeInMs: Long): Unit = {
+ val encodedValue = stateTypesEncoder.encodeValue(newState, expirationTimeInMs)
+
+ val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+ store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
+ encodedValue, stateName)
+
+ if (expirationTimeInMs != -1) {
+ upsertTTLForStateKey(expirationTimeInMs, serializedGroupingKey)
+ }
+ }
+
+ /** Function to remove state for given key */
+ override def clear(): Unit = {
+ store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+ }
+
+ def clearIfExpired(groupingKey: Array[Byte]): Unit = {
+ val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey)
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ if (isExpired(retRow)) {
+ store.remove(encodedGroupingKey, stateName)
+ }
+ }
+ }
+
+ private def isExpired(valueRow: UnsafeRow): Boolean = {
+ val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow)
+ val isExpired = expirationMs.map(
Review Comment:
`expirationMs.filter(StateTTL.isExpired(_, batchTtlExpirationMs)).isDefined`
The above is `true` only when expirationMs is `Some(x)` and
`StateTTL.isExpired(x, batchTtlExpirationMs)` is `true`.
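The suggested rewrite can be checked in isolation; `expired` below is a stand-in for `StateTTL.isExpired` with illustrative threshold semantics:

```scala
object TtlCheckDemo {
  // Stand-in for StateTTL.isExpired: expired once the batch timestamp
  // reaches the expiration time (illustrative semantics).
  private def expired(expirationMs: Long, batchTsMs: Long): Boolean =
    expirationMs <= batchTsMs

  // Suggested form: true only when expirationMs is Some(x) and expired(x, ts).
  def isExpired(expirationMs: Option[Long], batchTsMs: Long): Boolean =
    expirationMs.filter(expired(_, batchTsMs)).isDefined
}
```

`Option.exists` expresses the same thing even more directly: `expirationMs.exists(expired(_, batchTsMs))`.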
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala:
##########
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{TTLMode, ValueState}
+
+/**
+ * Class that provides a concrete implementation for a single value state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
+ * @param ttlMode - TTL Mode for values stored in this state
+ * @param batchTtlExpirationMs - ttl expiration for the current batch.
+ * @tparam S - data type of object that will be stored
+ */
+class ValueStateImplWithTTL[S](
+ store: StateStore,
+ stateName: String,
+ keyExprEnc: ExpressionEncoder[Any],
+ valEncoder: Encoder[S],
+ ttlMode: TTLMode,
+ batchTtlExpirationMs: Long)
+ extends SingleKeyTTLStateImpl(stateName, store, batchTtlExpirationMs) with ValueState[S] {
+
+ private val keySerializer = keyExprEnc.createSerializer()
+ private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder,
+ stateName, hasTtl = true)
+
+ initialize()
+
+ private def initialize(): Unit = {
+ assert(ttlMode != TTLMode.NoTTL())
+
+ store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+ NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+ }
+
/** Function to check if state exists. Returns true if present and false otherwise */
+ override def exists(): Boolean = {
+ get() != null
+ }
+
+ /** Function to return Option of value if exists and None otherwise */
+ override def getOption(): Option[S] = {
+ Option(get())
+ }
+
/** Function to return associated value with key if exists and null otherwise */
+ override def get(): S = {
+ val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ val resState = stateTypesEncoder.decodeValue(retRow)
+
+ if (!isExpired(retRow)) {
+ resState
+ } else {
+ null.asInstanceOf[S]
+ }
+ } else {
+ null.asInstanceOf[S]
+ }
+ }
+
+ /** Function to update and overwrite state associated with given key */
+ override def update(
+ newState: S,
+ ttlDuration: Duration = Duration.ZERO): Unit = {
+
+ if (ttlMode == TTLMode.EventTimeTTL() && ttlDuration != Duration.ZERO) {
+ throw StateStoreErrors.cannotProvideTTLDurationForEventTimeTTLMode("update", stateName)
+ }
+
+ if (ttlDuration != null && ttlDuration.isNegative) {
+ throw StateStoreErrors.ttlCannotBeNegative("update", stateName)
+ }
+
+ val expirationTimeInMs =
+ if (ttlDuration != null && ttlDuration != Duration.ZERO) {
+ StateTTL.calculateExpirationTimeForDuration(ttlDuration, batchTtlExpirationMs)
+ } else {
+ -1
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ override def update(
+ newState: S,
+ expirationTimeInMs: Long): Unit = {
+
+ if (expirationTimeInMs < 0) {
+ throw StateStoreErrors.ttlCannotBeNegative(
+ "update", stateName)
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ private def doUpdate(newState: S,
+ expirationTimeInMs: Long): Unit = {
+ val encodedValue = stateTypesEncoder.encodeValue(newState, expirationTimeInMs)
+
+ val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+ store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
+ encodedValue, stateName)
+
+ if (expirationTimeInMs != -1) {
+ upsertTTLForStateKey(expirationTimeInMs, serializedGroupingKey)
+ }
+ }
+
+ /** Function to remove state for given key */
+ override def clear(): Unit = {
+ store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+ }
+
+ def clearIfExpired(groupingKey: Array[Byte]): Unit = {
+ val encodedGroupingKey = stateTypesEncoder.encodeSerializedGroupingKey(groupingKey)
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ if (isExpired(retRow)) {
+ store.remove(encodedGroupingKey, stateName)
+ }
+ }
+ }
+
+ private def isExpired(valueRow: UnsafeRow): Boolean = {
+ val expirationMs = stateTypesEncoder.decodeTtlExpirationMs(valueRow)
+ val isExpired = expirationMs.map(
+ StateTTL.isExpired(_, batchTtlExpirationMs))
+
+ isExpired.isDefined && isExpired.get
+ }
+
+ /*
+ * Internal methods to probe state for testing. The below methods exist for unit tests
+ * to read the state ttl values, and ensure that values are persisted correctly in
+ * the underlying state store.
+ */
+
+ /**
+ * Retrieves the value from State even if its expired. This method is used
+ * in tests to read the state store value, and ensure if its cleaned up at the
+ * end of the micro-batch.
+ */
+ private[sql] def getWithoutEnforcingTTL(): Option[S] = {
+ val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ val resState = stateTypesEncoder.decodeValue(retRow)
+ Some(resState)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Read the ttl value associated with the grouping key.
+ */
+ private[sql] def getTTLValue(): Option[Long] = {
+ val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ stateTypesEncoder.decodeTtlExpirationMs(retRow)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Get all ttl values stored in ttl state for current implicit
+ * grouping key.
+ */
+ private[sql] def getValuesInTTLState(): Iterator[Long] = {
+ val ttlIterator = ttlIndexIterator()
+ val implicitGroupingKey = stateTypesEncoder.serializeGroupingKey()
+ var nextValue: Option[Long] = None
+
+ new Iterator[Long] {
Review Comment:
FYI, `NextIterator` provides this pattern of implementation, though this is
simple enough that it's fine as it is.
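For reference, the pattern `NextIterator` (in `org.apache.spark.util`) captures is single-element lookahead with a `finished` flag; a stripped-down re-implementation (omitting the real class's `close()` handling) looks roughly like:

```scala
abstract class SimpleNextIterator[T] extends Iterator[T] {
  private var gotNext = false
  private var nextValue: T = _
  protected var finished = false

  /** Produce the next value, or set `finished = true` when exhausted. */
  protected def getNext(): T

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      gotNext = true
    }
    !finished
  }

  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException("end of iterator")
    gotNext = false
    nextValue
  }
}

object SimpleNextIterator {
  // Example adapter: re-expose a plain iterator through getNext()/finished.
  // `sentinel` is a throwaway value returned once when the source is drained.
  def wrap[T](src: Iterator[T], sentinel: T): Iterator[T] = new SimpleNextIterator[T] {
    override protected def getNext(): T =
      if (src.hasNext) src.next() else { finished = true; sentinel }
  }
}
```

A subclass only implements `getNext()`, which is what makes it a good fit for the "peek the next qualifying TTL row" iterator above.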
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##########
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.execution.streaming.state.{RangeKeyScanStateEncoderSpec, StateStore}
+import org.apache.spark.sql.types.{BinaryType, DataType, LongType, NullType, StructField, StructType}
+
+object StateTTLSchema {
+ val TTL_KEY_ROW_SCHEMA: StructType = new StructType()
+ .add("expirationMs", LongType)
+ .add("groupingKey", BinaryType)
+ val TTL_VALUE_ROW_SCHEMA: StructType =
+ StructType(Array(StructField("__dummy__", NullType)))
+}
+
+/**
+ * Encapsulates the ttl row information stored in [[SingleKeyTTLStateImpl]].
+ *
+ * @param groupingKey grouping key for which ttl is set
+ * @param expirationMs expiration time for the grouping key
+ */
+case class SingleKeyTTLRow(
+ groupingKey: Array[Byte],
+ expirationMs: Long)
+
+/**
+ * Represents the underlying state for secondary TTL Index for a user defined
+ * state variable.
+ *
+ * This state allows Spark to query ttl values based on expiration time
+ * allowing efficient ttl cleanup.
+ */
+trait TTLState {
+
+ /**
+ * Perform the user state clean up based on ttl values stored in
+ * this state. NOTE that its not safe to call this operation concurrently
+ * when the user can also modify the underlying State. Cleanup should be initiated
+ * after arbitrary state operations are completed by the user.
+ */
+ def clearExpiredState(): Unit
+
+ /**
+ * Clears the user state associated with this grouping key
+ * if it has expired. This function is called by Spark to perform
+ * cleanup at the end of transformWithState processing.
+ *
+ * Spark uses a secondary index to determine if the user state for
+ * this grouping key has expired. However, its possible that the user
+ * has updated the TTL and secondary index is out of date. Implementations
Review Comment:
I see the semantic of upsertTTLForStateKey - so this is necessary.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala:
##########
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{TTLMode, ValueState}
+
+/**
+ * Class that provides a concrete implementation for a single value state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
+ * @param ttlMode - TTL Mode for values stored in this state
+ * @param batchTtlExpirationMs - ttl expiration for the current batch.
+ * @tparam S - data type of object that will be stored
+ */
+class ValueStateImplWithTTL[S](
+ store: StateStore,
+ stateName: String,
+ keyExprEnc: ExpressionEncoder[Any],
+ valEncoder: Encoder[S],
+ ttlMode: TTLMode,
+ batchTtlExpirationMs: Long)
+ extends SingleKeyTTLStateImpl(stateName, store, batchTtlExpirationMs) with ValueState[S] {
+
+ private val keySerializer = keyExprEnc.createSerializer()
+ private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder,
+ stateName, hasTtl = true)
+
+ initialize()
+
+ private def initialize(): Unit = {
+ assert(ttlMode != TTLMode.NoTTL())
+
+ store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+ NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+ }
+
/** Function to check if state exists. Returns true if present and false otherwise */
+ override def exists(): Boolean = {
+ get() != null
+ }
+
+ /** Function to return Option of value if exists and None otherwise */
+ override def getOption(): Option[S] = {
+ Option(get())
+ }
+
/** Function to return associated value with key if exists and null otherwise */
+ override def get(): S = {
+ val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ val resState = stateTypesEncoder.decodeValue(retRow)
+
+ if (!isExpired(retRow)) {
+ resState
+ } else {
+ null.asInstanceOf[S]
+ }
+ } else {
+ null.asInstanceOf[S]
+ }
+ }
+
+ /** Function to update and overwrite state associated with given key */
+ override def update(
+ newState: S,
+ ttlDuration: Duration = Duration.ZERO): Unit = {
+
+ if (ttlMode == TTLMode.EventTimeTTL() && ttlDuration != Duration.ZERO) {
+ throw StateStoreErrors.cannotProvideTTLDurationForEventTimeTTLMode("update", stateName)
+ }
+
+ if (ttlDuration != null && ttlDuration.isNegative) {
+ throw StateStoreErrors.ttlCannotBeNegative("update", stateName)
+ }
+
+ val expirationTimeInMs =
+ if (ttlDuration != null && ttlDuration != Duration.ZERO) {
+ StateTTL.calculateExpirationTimeForDuration(ttlDuration, batchTtlExpirationMs)
+ } else {
+ -1
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ override def update(
+ newState: S,
+ expirationTimeInMs: Long): Unit = {
+
+ if (expirationTimeInMs < 0) {
+ throw StateStoreErrors.ttlCannotBeNegative(
+ "update", stateName)
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ private def doUpdate(newState: S,
Review Comment:
nit: move `newState` to the next line
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##########
@@ -89,14 +124,29 @@ class StateTypesEncoder[GK, V](
val value = rowToObjDeserializer.apply(reusedValRow)
value
}
+
+ /**
+ * Decode the ttl information out of Value row. If the ttl has
+ * not been set (-1L specifies no user defined value), the API will
+ * return None.
+ */
+ def decodeTtlExpirationMs(row: UnsafeRow): Option[Long] = {
+ val expirationMs = row.getLong(1)
Review Comment:
nit: Maybe check `hasTtl` as a safety guard? I'm fine with assuming that the
caller knows the flag and takes responsibility, but if this isn't a hot
codepath, I'd feel more comfortable doing the check.
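The guard being proposed could look like the following standalone sketch; the `-1L` "no TTL" sentinel follows the quoted diff, while the method shape is simplified (a plain `Long` instead of reading ordinal 1 of an `UnsafeRow`):

```scala
class TtlDecoder(hasTtl: Boolean) {
  // Guarded variant: fail fast if this encoder was built without TTL support,
  // instead of silently reading a column that does not carry a TTL.
  def decodeTtlExpirationMs(expirationMs: Long): Option[Long] = {
    assert(hasTtl, "cannot decode ttl expiration for a state variable without ttl")
    if (expirationMs == -1L) None else Some(expirationMs)
  }
}
```

The `assert` keeps the hot path cheap (it can be disabled via `-Xdisable-assertions`) while still catching misuse in tests.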
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -925,15 +926,15 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
hasInitialState, planLater(initialState), planLater(child)
) :: Nil
case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
- dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder,
+ dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode, keyEncoder,
outputObjAttr, child, hasInitialState,
initialStateGroupingAttrs, initialStateDataAttrs,
initialStateDeserializer, initialState) =>
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
- groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
+ groupingAttributes, dataAttributes, statefulProcessor, ttlMode, timeoutMode, outputMode,
keyEncoder, outputObjAttr, planLater(child), hasInitialState,
initialStateGroupingAttrs, initialStateDataAttrs,
- initialStateDeserializer, planLater(initialState)) :: Nil
+ initialStateDeserializer, planLater (initialState)) :: Nil
Review Comment:
nit: unnecessary space?
##########
sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java:
##########
@@ -85,7 +89,7 @@ public scala.collection.Iterator<String> handleInputRows(
}
count += numRows;
- countState.update(count);
+ countState.update(count, Duration.ZERO);
Review Comment:
I probably should have commented on this on the API side, but I just realized it
here. I feel like allowing a flexible TTL on every update may be overkill.
My understanding is that Flink lets users set the TTL config as part of the
state's spec, not something users can define per update. Do we have a specific
use case in mind?
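A sketch of the alternative the reviewer describes (TTL fixed once per state, as in Flink's StateTtlConfig, rather than per update() call). TtlSpec and SimpleValueState are illustrative names, not Spark APIs:

```scala
import java.time.Duration

// TTL declared once as part of the state's spec.
final case class TtlSpec(ttl: Duration)

class SimpleValueState[S](spec: Option[TtlSpec]) {
  // (value, expirationMs); -1L means no TTL.
  private var state: Option[(S, Long)] = None

  // update() takes only the value; expiration is derived from the fixed spec,
  // so callers never pass a Duration at the call site.
  def update(newValue: S, nowMs: Long): Unit = {
    val expirationMs = spec.map(s => nowMs + s.ttl.toMillis).getOrElse(-1L)
    state = Some((newValue, expirationMs))
  }

  // Reads filter out expired values lazily.
  def get(nowMs: Long): Option[S] = state.collect {
    case (v, exp) if exp == -1L || exp > nowMs => v
  }
}
```

With this shape, the non-TTL and TTL code paths share the same update signature, which also sidesteps the Duration.ZERO churn seen in the Java test below.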
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala:
##########
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.time.Duration
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+ import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{TTLMode, ValueState}
+
+/**
+ * Class that provides a concrete implementation for a single value state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param valEncoder - Spark SQL encoder for value
+ * @param ttlMode - TTL Mode for values stored in this state
+ * @param batchTtlExpirationMs - ttl expiration for the current batch.
+ * @tparam S - data type of object that will be stored
+ */
+class ValueStateImplWithTTL[S](
+ store: StateStore,
+ stateName: String,
+ keyExprEnc: ExpressionEncoder[Any],
+ valEncoder: Encoder[S],
+ ttlMode: TTLMode,
+ batchTtlExpirationMs: Long)
+ extends SingleKeyTTLStateImpl(stateName, store, batchTtlExpirationMs) with ValueState[S] {
+
+ private val keySerializer = keyExprEnc.createSerializer()
+ private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder,
+ stateName, hasTtl = true)
+
+ initialize()
+
+ private def initialize(): Unit = {
+ assert(ttlMode != TTLMode.NoTTL())
+
+ store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+ NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA))
+ }
+
+ /** Function to check if state exists. Returns true if present and false otherwise */
+ override def exists(): Boolean = {
+ get() != null
+ }
+
+ /** Function to return Option of value if exists and None otherwise */
+ override def getOption(): Option[S] = {
+ Option(get())
+ }
+
+ /** Function to return associated value with key if exists and null otherwise */
+ override def get(): S = {
+ val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+ val retRow = store.get(encodedGroupingKey, stateName)
+
+ if (retRow != null) {
+ val resState = stateTypesEncoder.decodeValue(retRow)
+
+ if (!isExpired(retRow)) {
+ resState
+ } else {
+ null.asInstanceOf[S]
+ }
+ } else {
+ null.asInstanceOf[S]
+ }
+ }
+
+ /** Function to update and overwrite state associated with given key */
+ override def update(
+ newState: S,
+ ttlDuration: Duration = Duration.ZERO): Unit = {
+
+ if (ttlMode == TTLMode.EventTimeTTL() && ttlDuration != Duration.ZERO) {
+ throw StateStoreErrors.cannotProvideTTLDurationForEventTimeTTLMode("update", stateName)
+ }
+ }
+
+ if (ttlDuration != null && ttlDuration.isNegative) {
+ throw StateStoreErrors.ttlCannotBeNegative("update", stateName)
+ }
+
+ val expirationTimeInMs =
+ if (ttlDuration != null && ttlDuration != Duration.ZERO) {
StateTTL.calculateExpirationTimeForDuration(ttlDuration, batchTtlExpirationMs)
+ } else {
+ -1
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ override def update(
+ newState: S,
+ expirationTimeInMs: Long): Unit = {
+
+ if (expirationTimeInMs < 0) {
+ throw StateStoreErrors.ttlCannotBeNegative(
+ "update", stateName)
+ }
+
+ doUpdate(newState, expirationTimeInMs)
+ }
+
+ private def doUpdate(newState: S,
+ expirationTimeInMs: Long): Unit = {
+ val encodedValue = stateTypesEncoder.encodeValue(newState, expirationTimeInMs)
+
+ val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+ store.put(stateTypesEncoder.encodeSerializedGroupingKey(serializedGroupingKey),
+ encodedValue, stateName)
+
+ if (expirationTimeInMs != -1) {
+ upsertTTLForStateKey(expirationTimeInMs, serializedGroupingKey)
Review Comment:
The method name is actually confusing - it makes me expect that we find the
old TTL value for this grouping key, remove it if any, and put the new
TTL. But in reality we seem to leave the old TTL entry as it is. What exactly is
being upserted seems ambiguous.
I understand we don't have a cross reference, so it's hard to remove the old
entry, but maybe we could make it clear that an entry is only removed via
expiration (there's no case where updating the value will "replace" the TTL
entry), so we shouldn't rely on the TTL entry having been replaced.
I see the remaining code accounts for this, but the method name still
reads that way to me. Maybe rename it to putNewTTLEntryForStateKey to be super
clear, or add a method doc to clarify?
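A minimal sketch of the semantics under discussion, using the reviewer's proposed name. TtlIndex is a hypothetical stand-in for the secondary TTL column family; the point is that putting a new entry never replaces an older one for the same key:

```scala
import scala.collection.mutable.ArrayBuffer

class TtlIndex {
  // (expirationMs, key) pairs; duplicates per key are allowed by design,
  // since there is no cross reference to find and delete the old entry.
  private val entries = ArrayBuffer.empty[(Long, String)]

  // "putNew" rather than "upsert": this only appends a fresh TTL entry.
  def putNewTTLEntryForStateKey(expirationMs: Long, key: String): Unit =
    entries += ((expirationMs, key))

  // Stale entries for a key are harmless: the cleanup path re-checks the
  // value row's own expiration before deleting, so firing on an outdated
  // index entry is a no-op.
  def expiredKeys(nowMs: Long): Seq[String] =
    entries.filter(_._1 <= nowMs).map(_._2).toSeq
}
```

Under this reading, the value row's expiration column is the source of truth and the TTL index is only a hint for cleanup, which matches how the surrounding code treats it.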
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala:
##########
@@ -78,7 +80,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
testState.update(123)
}
checkError(
- ex.asInstanceOf[SparkException],
+ ex1.asInstanceOf[SparkException],
Review Comment:
Nice finding :)
##########
sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java:
##########
@@ -85,7 +89,7 @@ public scala.collection.Iterator<String> handleInputRows(
}
count += numRows;
- countState.update(count);
+ countState.update(count, Duration.ZERO);
Review Comment:
It looks like regressing the UX on non-TTL case.