micheal-o commented on code in PR #48401:
URL: https://github.com/apache/spark/pull/48401#discussion_r1823787543


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2204,6 +2204,16 @@ object SQLConf {
       .intConf
       .createWithDefault(3)
 
+  val STREAMING_STATE_STORE_ENCODING_FORMAT =
+    buildConf("spark.sql.streaming.stateStore.encodingFormat")
+      .doc("The encoding format used for stateful operators to store 
information" +
+        "in the state store")
+      .version("4.0.0")
+      .stringConf
+      .checkValue(v => Set("UnsafeRow", "Avro").contains(v),
+        "Valid versions are 'UnsafeRow' and 'Avro'")

Review Comment:
   nit: did you mean "Valid values" rather than "Valid versions"?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -718,8 +732,7 @@ object TransformWithStateExec {
       queryRunId = UUID.randomUUID(),
       operatorId = 0,
       storeVersion = 0,
-      numPartitions = shufflePartitions,
-      stateStoreCkptIds = None

Review Comment:
   Is removing `stateStoreCkptIds` here intentional?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -49,101 +50,187 @@ class ListStateImpl[S](
   override def baseStateName: String = stateName
   override def exprEncSchema: StructType = keyExprEnc.schema
 
-  private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, 
stateName)
+  // If we are using Avro, the avroSerde parameter must be populated
+  // else, we will default to using UnsafeRow.
+  private val usingAvro: Boolean = avroEnc.isDefined
+  private val avroTypesEncoder = new AvroTypesEncoder[S](
+    keyExprEnc, valEncoder, stateName, hasTtl = false, avroEnc)
+  private val unsafeRowTypesEncoder = new UnsafeRowTypesEncoder[S](
+    keyExprEnc, valEncoder, stateName, hasTtl = false)
 
   store.createColFamilyIfAbsent(stateName, keyExprEnc.schema, 
valEncoder.schema,
     NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = 
true)
 
   /** Whether state exists or not. */
-   override def exists(): Boolean = {
-     val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
-     val stateValue = store.get(encodedGroupingKey, stateName)
-     stateValue != null
-   }
-
-   /**
-    * Get the state value if it exists. If the state does not exist in state 
store, an
-    * empty iterator is returned.
-    */
-   override def get(): Iterator[S] = {
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
-     new Iterator[S] {
-       override def hasNext: Boolean = {
-         unsafeRowValuesIterator.hasNext
-       }
-
-       override def next(): S = {
-         val valueUnsafeRow = unsafeRowValuesIterator.next()
-         stateTypesEncoder.decodeValue(valueUnsafeRow)
-       }
-     }
-   }
-
-   /** Update the value of the list. */
-   override def put(newState: Array[S]): Unit = {
-     validateNewState(newState)
-
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     var isFirst = true
-     var entryCount = 0L
-     TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
-
-     newState.foreach { v =>
-       val encodedValue = stateTypesEncoder.encodeValue(v)
-       if (isFirst) {
-         store.put(encodedKey, encodedValue, stateName)
-         isFirst = false
-       } else {
-         store.merge(encodedKey, encodedValue, stateName)
-       }
-       entryCount += 1
-       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     }
-     updateEntryCount(encodedKey, entryCount)
-   }
-
-   /** Append an entry to the list. */
-   override def appendValue(newState: S): Unit = {
-     StateStoreErrors.requireNonNullStateValue(newState, stateName)
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     val entryCount = getEntryCount(encodedKey)
-     store.merge(encodedKey,
-         stateTypesEncoder.encodeValue(newState), stateName)
-     TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     updateEntryCount(encodedKey, entryCount + 1)
-   }
-
-   /** Append an entire list to the existing value. */
-   override def appendList(newState: Array[S]): Unit = {
-     validateNewState(newState)
-
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     var entryCount = getEntryCount(encodedKey)
-     newState.foreach { v =>
-       val encodedValue = stateTypesEncoder.encodeValue(v)
-       store.merge(encodedKey, encodedValue, stateName)
-       entryCount += 1
-       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     }
-     updateEntryCount(encodedKey, entryCount)
-   }
-
-   /** Remove this state. */
-   override def clear(): Unit = {
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     store.remove(encodedKey, stateName)
-     val entryCount = getEntryCount(encodedKey)
-     TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", 
entryCount)
-     removeEntryCount(encodedKey)
-   }
-
-   private def validateNewState(newState: Array[S]): Unit = {
-     StateStoreErrors.requireNonNullStateValue(newState, stateName)
-     StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)
-
-     newState.foreach { v =>
-       StateStoreErrors.requireNonNullStateValue(v, stateName)
-     }
-   }
- }
+  override def exists(): Boolean = {
+    if (usingAvro) {
+      val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey()
+      store.get(encodedKey, stateName) != null
+    } else {
+      val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
+      store.get(encodedKey, stateName) != null
+    }
+  }
+
+  /**
+   * Get the state value if it exists. If the state does not exist in state 
store, an
+   * empty iterator is returned.
+   */
+  override def get(): Iterator[S] = {
+    if (usingAvro) {
+      getAvro()
+    } else {
+      getUnsafeRow()
+    }
+  }
+
+  private def getAvro(): Iterator[S] = {
+    val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey()
+    val avroValuesIterator = store.valuesIterator(encodedKey, stateName)
+    new Iterator[S] {
+      override def hasNext: Boolean = {
+        avroValuesIterator.hasNext
+      }
+
+      override def next(): S = {
+        val valueRow = avroValuesIterator.next()
+        avroTypesEncoder.decodeValue(valueRow)
+      }
+    }
+  }
+
+  private def getUnsafeRow(): Iterator[S] = {
+    val encodedKey = unsafeRowTypesEncoder.encodeGroupingKey()
+    val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
+    new Iterator[S] {
+      override def hasNext: Boolean = {
+        unsafeRowValuesIterator.hasNext
+      }
+
+      override def next(): S = {
+        val valueUnsafeRow = unsafeRowValuesIterator.next()
+        unsafeRowTypesEncoder.decodeValue(valueUnsafeRow)
+      }
+    }
+  }
+
+  /** Update the value of the list. */
+  override def put(newState: Array[S]): Unit = {
+    validateNewState(newState)
+
+    if (usingAvro) {
+      putAvro(newState)
+    } else {
+      putUnsafeRow(newState)
+    }
+  }
+
+  private def putAvro(newState: Array[S]): Unit = {
+    val encodedKey: Array[Byte] = avroTypesEncoder.encodeGroupingKey()
+    var isFirst = true
+    var entryCount = 0L
+    TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
+
+    newState.foreach { v =>
+      val encodedValue = avroTypesEncoder.encodeValue(v)
+      if (isFirst) {
+        store.put(encodedKey, encodedValue, stateName)
+        isFirst = false
+      } else {
+        store.merge(encodedKey, encodedValue, stateName)
+      }
+      entryCount += 1
+      TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
+    }

Review Comment:
   `putUnsafeRow` calls `updateEntryCount`, but `putAvro` does not. Why?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -49,101 +50,187 @@ class ListStateImpl[S](
   override def baseStateName: String = stateName
   override def exprEncSchema: StructType = keyExprEnc.schema
 
-  private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, 
stateName)
+  // If we are using Avro, the avroSerde parameter must be populated
+  // else, we will default to using UnsafeRow.
+  private val usingAvro: Boolean = avroEnc.isDefined
+  private val avroTypesEncoder = new AvroTypesEncoder[S](
+    keyExprEnc, valEncoder, stateName, hasTtl = false, avroEnc)
+  private val unsafeRowTypesEncoder = new UnsafeRowTypesEncoder[S](
+    keyExprEnc, valEncoder, stateName, hasTtl = false)
 
   store.createColFamilyIfAbsent(stateName, keyExprEnc.schema, 
valEncoder.schema,
     NoPrefixKeyStateEncoderSpec(keyExprEnc.schema), useMultipleValuesPerKey = 
true)
 
   /** Whether state exists or not. */
-   override def exists(): Boolean = {
-     val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
-     val stateValue = store.get(encodedGroupingKey, stateName)
-     stateValue != null
-   }
-
-   /**
-    * Get the state value if it exists. If the state does not exist in state 
store, an
-    * empty iterator is returned.
-    */
-   override def get(): Iterator[S] = {
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
-     new Iterator[S] {
-       override def hasNext: Boolean = {
-         unsafeRowValuesIterator.hasNext
-       }
-
-       override def next(): S = {
-         val valueUnsafeRow = unsafeRowValuesIterator.next()
-         stateTypesEncoder.decodeValue(valueUnsafeRow)
-       }
-     }
-   }
-
-   /** Update the value of the list. */
-   override def put(newState: Array[S]): Unit = {
-     validateNewState(newState)
-
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     var isFirst = true
-     var entryCount = 0L
-     TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")
-
-     newState.foreach { v =>
-       val encodedValue = stateTypesEncoder.encodeValue(v)
-       if (isFirst) {
-         store.put(encodedKey, encodedValue, stateName)
-         isFirst = false
-       } else {
-         store.merge(encodedKey, encodedValue, stateName)
-       }
-       entryCount += 1
-       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     }
-     updateEntryCount(encodedKey, entryCount)
-   }
-
-   /** Append an entry to the list. */
-   override def appendValue(newState: S): Unit = {
-     StateStoreErrors.requireNonNullStateValue(newState, stateName)
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     val entryCount = getEntryCount(encodedKey)
-     store.merge(encodedKey,
-         stateTypesEncoder.encodeValue(newState), stateName)
-     TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     updateEntryCount(encodedKey, entryCount + 1)
-   }
-
-   /** Append an entire list to the existing value. */
-   override def appendList(newState: Array[S]): Unit = {
-     validateNewState(newState)
-
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     var entryCount = getEntryCount(encodedKey)
-     newState.foreach { v =>
-       val encodedValue = stateTypesEncoder.encodeValue(v)
-       store.merge(encodedKey, encodedValue, stateName)
-       entryCount += 1
-       TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
-     }
-     updateEntryCount(encodedKey, entryCount)
-   }
-
-   /** Remove this state. */
-   override def clear(): Unit = {
-     val encodedKey = stateTypesEncoder.encodeGroupingKey()
-     store.remove(encodedKey, stateName)
-     val entryCount = getEntryCount(encodedKey)
-     TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", 
entryCount)
-     removeEntryCount(encodedKey)
-   }
-
-   private def validateNewState(newState: Array[S]): Unit = {
-     StateStoreErrors.requireNonNullStateValue(newState, stateName)
-     StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)
-
-     newState.foreach { v =>
-       StateStoreErrors.requireNonNullStateValue(v, stateName)
-     }
-   }
- }
+  override def exists(): Boolean = {
+    if (usingAvro) {

Review Comment:
   I see a lot of duplicated code across all the functions in this file, and the 
same is true of the ValueState implementation. There should be a better way to do 
this without duplicating so much code. Both encoders implement 
`StateTypesEncoder`; the only difference is that one returns `UnsafeRow` and the 
other returns `Array[Byte]`.
   
   Also, the repeated `if (usingAvro)` checks make the code harder to read. Can 
we look into using a factory method, or doing something like this: 
   `private val stateTypesEncoder: StateTypesEncoder  = if (usingAvro) new Avro 
else new Unsafe`
   
   and then just using `stateTypesEncoder` in the methods instead of repeating 
the `usingAvro` check?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##########
@@ -73,11 +88,11 @@ object TransformWithStateKeyValueRowSchemaUtils {
  * @param stateName - name of logical state partition
  * @tparam V - value type
  */
-class StateTypesEncoder[V](

Review Comment:
   nit: fix the class comment.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -40,6 +40,7 @@ class ListStateImpl[S](
      stateName: String,
      keyExprEnc: ExpressionEncoder[Any],
      valEncoder: Encoder[S],
+     avroEnc: Option[AvroEncoderSpec],

Review Comment:
   nit: param comment is missing for this
   
   Also, is there a reason not to default it to `None`? That way you wouldn't 
have to manually pass in `None` at all the existing call sites.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##########
@@ -59,6 +65,15 @@ object TransformWithStateKeyValueRowSchemaUtils {
   }
 }
 
+trait StateTypesEncoder[V, S] {

Review Comment:
   nit: could you add a comment documenting this trait?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala:
##########
@@ -101,6 +101,24 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     override def valuesIterator(key: UnsafeRow, colFamilyName: String): 
Iterator[UnsafeRow] = {
       throw 
StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", 
"HDFSStateStore")
     }
+
+

Review Comment:
   nit: remove the extra blank line?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##########
@@ -143,23 +158,124 @@ class StateTypesEncoder[V](
   }
 }
 
-object StateTypesEncoder {
+
+object UnsafeRowTypesEncoder {
   def apply[V](
       keyEncoder: ExpressionEncoder[Any],
       valEncoder: Encoder[V],
       stateName: String,
-      hasTtl: Boolean = false): StateTypesEncoder[V] = {
-    new StateTypesEncoder[V](keyEncoder, valEncoder, stateName, hasTtl)
+      hasTtl: Boolean = false): UnsafeRowTypesEncoder[V] = {
+    new UnsafeRowTypesEncoder[V](keyEncoder, valEncoder, stateName, hasTtl)
+  }
+}
+
+/**
+ * Helper class providing APIs to encode the grouping key, and user provided 
values
+ * to Spark [[UnsafeRow]].
+ *
+ * CAUTION: StateTypesEncoder class instance is *not* thread-safe.
+ * This class reuses the keyProjection and valueProjection for encoding 
grouping
+ * key and state value respectively. As UnsafeProjection is not thread safe, 
this
+ * class is also not thread safe.
+ *

Review Comment:
   nit: fix the class comment (it still refers to `StateTypesEncoder`).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to