ericm-db commented on code in PR #48401:
URL: https://github.com/apache/spark/pull/48401#discussion_r1823797894
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -49,101 +50,187 @@ class ListStateImpl[S](
override def baseStateName: String = stateName
override def exprEncSchema: StructType = keyExprEnc.schema
- private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder,
stateName)
+ // If we are using Avro, the avroEnc parameter must be populated;
+ // otherwise, we will default to using UnsafeRow.
+ private val usingAvro: Boolean = avroEnc.isDefined
+ private val avroTypesEncoder = new AvroTypesEncoder[S](
+ keyExprEnc, valEncoder, stateName, hasTtl = false, avroEnc)
+ private val unsafeRowTypesEncoder = new UnsafeRowTypesEncoder[S](
+ keyExprEnc, valEncoder, stateName, hasTtl = false)
store.createColFamilyIfAbsent(stateName, keyExprEnc.schema,
valEncoder.schema,
NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey =
true)
/** Whether state exists or not. */
- override def exists(): Boolean = {
- val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
- val stateValue = store.get(encodedGroupingKey, stateName)
- stateValue != null
- }
-
- /**
- * Get the state value if it exists. If the state does not exist in state
store, an
- * empty iterator is returned.
- */
- override def get(): Iterator[S] = {
- val encodedKey = stateTypesEncoder.encodeGroupingKey()
- val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
- new Iterator[S] {
- override def hasNext: Boolean = {
- unsafeRowValuesIterator.hasNext
- }
-
- override def next(): S = {
- val valueUnsafeRow = unsafeRowValuesIterator.next()
- stateTypesEncoder.decodeValue(valueUnsafeRow)
- }
- }
- }
-
- /** Update the value of the list. */
- override def put(newState: Array[S]): Unit = {
- validateNewState(newState)
-
- val encodedKey = stateTypesEncoder.encodeGroupingKey()
- var isFirst = true
- var entryCount = 0L
- TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
-
- newState.foreach { v =>
- val encodedValue = stateTypesEncoder.encodeValue(v)
- if (isFirst) {
- store.put(encodedKey, encodedValue, stateName)
- isFirst = false
- } else {
- store.merge(encodedKey, encodedValue, stateName)
- }
- entryCount += 1
- TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
- }
- updateEntryCount(encodedKey, entryCount)
- }
-
- /** Append an entry to the list. */
- override def appendValue(newState: S): Unit = {
- StateStoreErrors.requireNonNullStateValue(newState, stateName)
- val encodedKey = stateTypesEncoder.encodeGroupingKey()
- val entryCount = getEntryCount(encodedKey)
- store.merge(encodedKey,
- stateTypesEncoder.encodeValue(newState), stateName)
- TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
- updateEntryCount(encodedKey, entryCount + 1)
- }
-
- /** Append an entire list to the existing value. */
- override def appendList(newState: Array[S]): Unit = {
- validateNewState(newState)
-
- val encodedKey = stateTypesEncoder.encodeGroupingKey()
- var entryCount = getEntryCount(encodedKey)
- newState.foreach { v =>
- val encodedValue = stateTypesEncoder.encodeValue(v)
- store.merge(encodedKey, encodedValue, stateName)
- entryCount += 1
- TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
- }
- updateEntryCount(encodedKey, entryCount)
- }
-
- /** Remove this state. */
- override def clear(): Unit = {
- val encodedKey = stateTypesEncoder.encodeGroupingKey()
- store.remove(encodedKey, stateName)
- val entryCount = getEntryCount(encodedKey)
- TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows",
entryCount)
- removeEntryCount(encodedKey)
- }
-
- private def validateNewState(newState: Array[S]): Unit = {
- StateStoreErrors.requireNonNullStateValue(newState, stateName)
- StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)
-
- newState.foreach { v =>
- StateStoreErrors.requireNonNullStateValue(v, stateName)
- }
- }
- }
+ override def exists(): Boolean = {
+ if (usingAvro) {
Review Comment:
Yeah, that's the first thing we tried. The problem is that if we do this, Scala
compilation fails when calling `store.get`, because `encodeGroupingKey` has
different output types depending on which encoder is used.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]