ericm-db commented on code in PR #48401:
URL: https://github.com/apache/spark/pull/48401#discussion_r1823798102


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -49,101 +50,187 @@ class ListStateImpl[S](
   override def baseStateName: String = stateName
   override def exprEncSchema: StructType = keyExprEnc.schema
 
-  private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, 
stateName)
+  // If we are using Avro, the avroSerde parameter must be populated
+  // else, we will default to using UnsafeRow.
+  private val usingAvro: Boolean = avroEnc.isDefined
+  private val avroTypesEncoder = new AvroTypesEncoder[S](
+    keyExprEnc, valEncoder, stateName, hasTtl = false, avroEnc)
+  private val unsafeRowTypesEncoder = new UnsafeRowTypesEncoder[S](
+    keyExprEnc, valEncoder, stateName, hasTtl = false)
 
   store.createColFamilyIfAbsent(stateName, keyExprEnc.schema, 
valEncoder.schema,
     NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = 
true)
 
   /** Whether state exists or not. */
-   override def exists(): Boolean = {
-     val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
-     val stateValue = store.get(encodedGroupingKey, stateName)
-     stateValue != null
-   }
-
-   /**
-    * Get the state value if it exists. If the state does not exist in state 
store, an
-    * empty iterator is returned.
-    */
-   override def get(): Iterator[S] = {
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
-     new Iterator[S] {
-       override def hasNext: Boolean = {
-         unsafeRowValuesIterator.hasNext
-       }
-
-       override def next(): S = {
-         val valueUnsafeRow = unsafeRowValuesIterator.next()
-         stateTypesEncoder.decodeValue(valueUnsafeRow)
-       }
-     }
-   }
-
-   /** Update the value of the list. */
-   override def put(newState: Array[S]): Unit = {
-     validateNewState(newState)
-
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     var isFirst = true
-     var entryCount = 0L
-     TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
-
-     newState.foreach { v =>
-       val encodedValue = stateTypesEncoder.encodeValue(v)
-       if (isFirst) {
-         store.put(encodedKey, encodedValue, stateName)
-         isFirst = false
-       } else {
-         store.merge(encodedKey, encodedValue, stateName)
-       }
-       entryCount += 1
-       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     }
-     updateEntryCount(encodedKey, entryCount)
-   }
-
-   /** Append an entry to the list. */
-   override def appendValue(newState: S): Unit = {
-     StateStoreErrors.requireNonNullStateValue(newState, stateName)
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     val entryCount = getEntryCount(encodedKey)
-     store.merge(encodedKey,
-         stateTypesEncoder.encodeValue(newState), stateName)
-     TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     updateEntryCount(encodedKey, entryCount + 1)
-   }
-
-   /** Append an entire list to the existing value. */
-   override def appendList(newState: Array[S]): Unit = {
-     validateNewState(newState)
-
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     var entryCount = getEntryCount(encodedKey)
-     newState.foreach { v =>
-       val encodedValue = stateTypesEncoder.encodeValue(v)
-       store.merge(encodedKey, encodedValue, stateName)
-       entryCount += 1
-       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     }
-     updateEntryCount(encodedKey, entryCount)
-   }
-
-   /** Remove this state. */
-   override def clear(): Unit = {
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     store.remove(encodedKey, stateName)
-     val entryCount = getEntryCount(encodedKey)
-     TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", 
entryCount)
-     removeEntryCount(encodedKey)
-   }
-
-   private def validateNewState(newState: Array[S]): Unit = {
-     StateStoreErrors.requireNonNullStateValue(newState, stateName)
-     StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)
-
-     newState.foreach { v =>
-       StateStoreErrors.requireNonNullStateValue(v, stateName)
-     }
-   }
- }
+  override def exists(): Boolean = {
+    if (usingAvro) {
+      val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey()
+      store.get(encodedKey, stateName) != null
+    } else {
+      val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
+      store.get(encodedKey, stateName) != null
+    }
+  }
+
+  /**
+   * Get the state value if it exists. If the state does not exist in state 
store, an
+   * empty iterator is returned.
+   */
+  override def get(): Iterator[S] = {
+    if (usingAvro) {
+      getAvro()
+    } else {
+      getUnsafeRow()
+    }
+  }
+
+  private def getAvro(): Iterator[S] = {
+    val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey()
+    val avroValuesIterator = store.valuesIterator(encodedKey, stateName)
+    new Iterator[S] {
+      override def hasNext: Boolean = {
+        avroValuesIterator.hasNext
+      }
+
+      override def next(): S = {
+        val valueRow = avroValuesIterator.next()
+        avroTypesEncoder.decodeValue(valueRow)
+      }
+    }
+  }
+
+  private def getUnsafeRow(): Iterator[S] = {
+    val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
+    val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
+    new Iterator[S] {
+      override def hasNext: Boolean = {
+        unsafeRowValuesIterator.hasNext
+      }
+
+      override def next(): S = {
+        val valueUnsafeRow = unsafeRowValuesIterator.next()
+        unsafeRowTypesEncoder.decodeValue(valueUnsafeRow)
+      }
+    }
+  }
+
+  /** Update the value of the list. */
+  override def put(newState: Array[S]): Unit = {
+    validateNewState(newState)
+
+    if (usingAvro) {
+      putAvro(newState)
+    } else {
+      putUnsafeRow(newState)
+    }
+  }
+
+  private def putAvro(newState: Array[S]): Unit = {
+    val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey()
+    var isFirst = true
+    var entryCount = 0L
+    TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
+
+    newState.foreach { v =>
+      val encodedValue = avroTypesEncoder.encodeValue(v)
+      if (isFirst) {
+        store.put(encodedKey, encodedValue, stateName)
+        isFirst = false
+      } else {
+        store.merge(encodedKey, encodedValue, stateName)
+      }
+      entryCount += 1
+      TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
+    }

Review Comment:
   Metrics aren't supported for byte arrays yet; I will add them in a follow-up.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to